1 /*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18 #include "RenderScript.h"
19 #include "rsCppInternal.h"
20
21 using namespace android;
22 using namespace RSC;
23
24 // ScriptIntrinsicBLAS APIS
ScriptIntrinsicBLAS(sp<RS> rs,sp<const Element> e)25 ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
26 : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
27
28 }
29
create(sp<RS> rs)30 sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) {
31 return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
32 }
33
// Selects which alpha/beta representation setUpBLASCall stores in the call
// descriptor.
enum RsBlasDataType {
    SINGLE = 0,          // float scalars
    DOUBLE = 1,          // double scalars
    SINGLE_COMPLEX = 2,  // float (real, imaginary) pairs
    DOUBLE_COMPLEX = 3   // double (real, imaginary) pairs
};
40
41 static RsBlasCall
setUpBLASCall(RsBlasDataType dataType,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,int incX,int incY,int KL,int KU,float alphaF,float betaF,double alphaD,double betaD,float alphaCX,float alphaCY,float betaCX,float betaCY,double alphaZX,double alphaZY,double betaZX,double betaZY)42 setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
43 int TransA, int TransB, int Side, int Uplo, int Diag,
44 int M, int N, int K, int incX, int incY, int KL, int KU,
45 float alphaF, float betaF, double alphaD, double betaD,
46 float alphaCX, float alphaCY, float betaCX, float betaCY,
47 double alphaZX, double alphaZY, double betaZX, double betaZY
48 ) {
49 RsBlasCall call;
50 memset(&call, 0, sizeof(call));
51 call.func = func;
52 call.transA = (RsBlasTranspose)TransA;
53 call.transB = (RsBlasTranspose)TransB;
54 call.side = (RsBlasSide)Side;
55 call.uplo = (RsBlasUplo)Uplo;
56 call.diag = (RsBlasDiag)Diag;
57 call.M = M;
58 call.N = N;
59 call.K = K;
60
61 switch (dataType) {
62 case SINGLE:
63 // For Single-precision BLAS.
64 call.alpha.f = alphaF;
65 call.beta.f = betaF;
66 break;
67 case DOUBLE:
68 // For Double-precision BLAS.
69 call.alpha.d = alphaD;
70 call.beta.d = betaD;
71 break;
72 case SINGLE_COMPLEX:
73 // For Single-precision complex BLAS.
74 call.alpha.c.r = alphaCX;
75 call.alpha.c.i = alphaCY;
76 call.beta.c.r = betaCX;
77 call.beta.c.i = betaCY;
78 break;
79 case DOUBLE_COMPLEX:
80 // For Double-precision complex BLAS.
81 call.alpha.z.r = alphaZX;
82 call.alpha.z.i = alphaZY;
83 call.beta.z.r = betaZX;
84 call.beta.z.i = betaZY;
85 break;
86 default:
87 break;
88 }
89
90 call.incX = incX;
91 call.incY = incY;
92 call.KL = KL;
93 call.KU = KU;
94
95 return call;
96 }
97
98 static void
nScriptIntrinsicBLAS_Single(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alpha,RsAllocation A,RsAllocation B,float beta,RsAllocation C,int incX,int incY,int KL,int KU)99 nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
100 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
101 float alpha, RsAllocation A, RsAllocation B,
102 float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
103 RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
104 M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
105 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
106 RsAllocation in_allocs[3] = {A, B, C};
107 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
108 &call, sizeof(call), nullptr, 0));
109 }
110
111
112 static void
nScriptIntrinsicBLAS_Double(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alpha,RsAllocation A,RsAllocation B,double beta,RsAllocation C,int incX,int incY,int KL,int KU)113 nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
114 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
115 double alpha, RsAllocation A, RsAllocation B,
116 double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
117 RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
118 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
119 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
120 RsAllocation in_allocs[3] = {A, B, C};
121 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
122 &call, sizeof(call), nullptr, 0));
123 }
124
125 static void
nScriptIntrinsicBLAS_Complex(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alphaX,float alphaY,RsAllocation A,RsAllocation B,float betaX,float betaY,RsAllocation C,int incX,int incY,int KL,int KU)126 nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
127 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
128 float alphaX, float alphaY, RsAllocation A, RsAllocation B,
129 float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
130 RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
131 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
132 alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
133 RsAllocation in_allocs[3] = {A, B, C};
134 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
135 &call, sizeof(call), nullptr, 0));
136 }
137
138 static void
nScriptIntrinsicBLAS_Z(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alphaX,double alphaY,RsAllocation A,RsAllocation B,double betaX,double betaY,RsAllocation C,int incX,int incY,int KL,int KU)139 nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
140 int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
141 double alphaX, double alphaY, RsAllocation A, RsAllocation B,
142 double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
143 RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
144 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
145 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
146 RsAllocation in_allocs[3] = {A, B, C};
147 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
148 &call, sizeof(call), nullptr, 0));
149 }
150
151
152 static void
nScriptIntrinsicBLAS_BNNM(RS * mRS,RsContext con,RsScript id,int M,int N,int K,RsAllocation A,int a_offset,RsAllocation B,int b_offset,RsAllocation C,int c_offset,int c_mult_int)153 nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
154 RsAllocation A, int a_offset, RsAllocation B, int b_offset,
155 RsAllocation C, int c_offset, int c_mult_int) {
156 RsBlasCall call;
157 memset(&call, 0, sizeof(call));
158 call.func = RsBlas_bnnm;
159 call.M = M;
160 call.N = N;
161 call.K = K;
162 call.a_offset = a_offset & 0xFF;
163 call.b_offset = b_offset & 0xFF;
164 call.c_offset = c_offset;
165 call.c_mult_int = c_mult_int;
166
167 RsAllocation in_allocs[3] = {A, B, C};
168 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
169 &call, sizeof(call), nullptr, 0));
170 }
171
172 /**
173 * Level 2 BLAS
174 */
validateGEMV(RS * mRS,sp<const Element> e,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> X,int incX,sp<Allocation> Y,int incY)175 static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A,
176 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
177 int M = A->getType()->getY();
178 int N = A->getType()->getX();
179 if (!A->getType()->getElement()->isCompatible(e) ||
180 !X->getType()->getElement()->isCompatible(e) ||
181 !Y->getType()->getElement()->isCompatible(e)) {
182 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
183 }
184 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
185 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
186 }
187
188 if (incX <= 0 || incY <= 0) {
189 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
190 }
191 int expectedXDim = -1, expectedYDim = -1;
192 if (TransA == RsBlasNoTrans) {
193 expectedXDim = 1 + (N - 1) * incX;
194 expectedYDim = 1 + (M - 1) * incY;
195 } else {
196 expectedXDim = 1 + (M - 1) * incX;
197 expectedYDim = 1 + (N - 1) * incY;
198 }
199 if ((int)X->getType()->getX() != expectedXDim ||
200 (int)Y->getType()->getX() != expectedYDim) {
201 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
202 }
203 }
204
SGEMV(RsBlasTranspose TransA,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)205 void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X,
206 int incX, float beta, sp<Allocation> Y, int incY) {
207 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
208 int M = A->getType()->getY();
209 int N = A->getType()->getX();
210 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
211 TransA, 0, 0, 0, 0, M, N, 0,
212 alpha, A->getID(), X->getID(),
213 beta, Y->getID(), incX, incY, 0, 0);
214 }
215
DGEMV(RsBlasTranspose TransA,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)216 void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X,
217 int incX, double beta, sp<Allocation> Y, int incY) {
218 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
219 int M = A->getType()->getY();
220 int N = A->getType()->getX();
221 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
222 TransA, 0, 0, 0, 0, M, N, 0,
223 alpha, A->getID(), X->getID(),
224 beta, Y->getID(), incX, incY, 0, 0);
225 }
226
CGEMV(RsBlasTranspose TransA,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)227 void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X,
228 int incX, Float2 beta, sp<Allocation> Y, int incY) {
229 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
230 int M = A->getType()->getY();
231 int N = A->getType()->getX();
232 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
233 TransA, 0, 0, 0, 0, M, N, 0,
234 alpha.x, alpha.y, A->getID(), X->getID(),
235 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
236 }
237
ZGEMV(RsBlasTranspose TransA,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)238 void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
239 int incX, Double2 beta, sp<Allocation> Y, int incY) {
240 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
241 int M = A->getType()->getY();
242 int N = A->getType()->getX();
243 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
244 TransA, 0, 0, 0, 0, M, N, 0,
245 alpha.x, alpha.y, A->getID(), X->getID(),
246 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
247 }
248
SGBMV(RsBlasTranspose TransA,int KL,int KU,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)249 void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A,
250 sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) {
251 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
252 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
253 if (KL < 0 || KU < 0) {
254 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
255 }
256 int M = A->getType()->getY();
257 int N = A->getType()->getX();
258
259 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
260 TransA, 0, 0, 0, 0, M, N, 0,
261 alpha, A->getID(), X->getID(),
262 beta, Y->getID(), incX, incY, KL, KU);
263 }
264
DGBMV(RsBlasTranspose TransA,int KL,int KU,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)265 void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A,
266 sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) {
267 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
268 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
269 if (KL < 0 || KU < 0) {
270 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
271 }
272 int M = A->getType()->getY();
273 int N = A->getType()->getX();
274
275 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
276 TransA, 0, 0, 0, 0, M, N, 0,
277 alpha, A->getID(), X->getID(),
278 beta, Y->getID(), incX, incY, KL, KU);
279 }
280
CGBMV(RsBlasTranspose TransA,int KL,int KU,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)281 void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A,
282 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
283 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
284 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
285 if (KL < 0 || KU < 0) {
286 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
287 }
288 int M = A->getType()->getY();
289 int N = A->getType()->getX();
290
291 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
292 TransA, 0, 0, 0, 0, M, N, 0,
293 alpha.x, alpha.y, A->getID(), X->getID(),
294 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
295 }
296
ZGBMV(RsBlasTranspose TransA,int KL,int KU,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)297 void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A,
298 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
299 // GBMV has the same validation requirements as GEMV + KL and KU >= 0
300 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
301 if (KL < 0 || KU < 0) {
302 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
303 }
304 int M = A->getType()->getY();
305 int N = A->getType()->getX();
306
307 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
308 TransA, 0, 0, 0, 0, M, N, 0,
309 alpha.x, alpha.y, A->getID(), X->getID(),
310 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
311 }
312
validateTRMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)313 static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
314 RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) {
315 int N = A->getType()->getY();
316 if ((int)A->getType()->getX() != N) {
317 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
318 }
319 if (!A->getType()->getElement()->isCompatible(e) ||
320 !X->getType()->getElement()->isCompatible(e)) {
321 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
322 }
323 if (X->getType()->getY() > 1) {
324 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
325 }
326
327 if (incX <= 0) {
328 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
329 }
330 int expectedXDim = 1 + (N - 1) * incX;
331 if ((int)X->getType()->getX() != expectedXDim) {
332 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
333 }
334 }
335
validateTPMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)336 static int validateTPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
337 RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) {
338 if (!Ap->getType()->getElement()->isCompatible(e) ||
339 !X->getType()->getElement()->isCompatible(e)) {
340 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
341 }
342 if (X->getType()->getY() > 1) {
343 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
344 }
345
346 if (Ap->getType()->getY() > 1) {
347 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
348 }
349
350 int N = sqrt((double)Ap->getType()->getX() * 2);
351 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
352 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
353 }
354 if (incX <= 0) {
355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
356 }
357 int expectedXDim = 1 + (N - 1) * incX;
358 if ((int)X->getType()->getX() != expectedXDim) {
359 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
360 }
361
362 return N;
363 }
364
365
STRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)366 void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
367 sp<Allocation> A, sp<Allocation> X, int incX) {
368 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
369 int N = A->getType()->getY();
370 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
371 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
372 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
373 }
374
DTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)375 void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
376 sp<Allocation> A, sp<Allocation> X, int incX) {
377 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
378 int N = A->getType()->getY();
379 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
380 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
381 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
382 }
383
CTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)384 void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
385 sp<Allocation> A, sp<Allocation> X, int incX) {
386 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
387 int N = A->getType()->getY();
388 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
389 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
390 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
391 }
392
ZTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)393 void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
394 sp<Allocation> A, sp<Allocation> X, int incX) {
395 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
396 int N = A->getType()->getY();
397 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
398 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
399 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
400 }
401
STBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)402 void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
403 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
404 // TBMV has the same requirements as TRMV + K >= 0
405 if (K < 0) {
406 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
407 }
408 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
409 int N = A->getType()->getY();
410 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
411 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
412 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
413 }
414
DTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)415 void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
416 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
417 // TBMV has the same requirements as TRMV + K >= 0
418 if (K < 0) {
419 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
420 }
421 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
422 int N = A->getType()->getY();
423 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
424 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
425 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
426 }
427
CTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)428 void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
429 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
430 // TBMV has the same requirements as TRMV + K >= 0
431 if (K < 0) {
432 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
433 }
434 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
435 int N = A->getType()->getY();
436 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
437 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
438 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
439 }
440
ZTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)441 void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
442 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
443 // TBMV has the same requirements as TRMV + K >= 0
444 if (K < 0) {
445 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
446 }
447 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
448 int N = A->getType()->getY();
449 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
450 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
451 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
452 }
453
STPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)454 void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
455 sp<Allocation> Ap, sp<Allocation> X, int incX) {
456 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
457 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
458 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
459 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
460 }
461
DTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)462 void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463 sp<Allocation> Ap, sp<Allocation> X, int incX) {
464 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
465 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
466 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468 }
469
CTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)470 void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471 sp<Allocation> Ap, sp<Allocation> X, int incX) {
472 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
473 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
474 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
475 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
476 }
477
ZTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)478 void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479 sp<Allocation> Ap, sp<Allocation> X, int incX) {
480 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
482 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484 }
485
STRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)486 void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487 sp<Allocation> A, sp<Allocation> X, int incX) {
488 // TRSV is the same as TRMV
489 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
490 int N = A->getType()->getY();
491 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
492 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
493 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
494 }
495
DTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)496 void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
497 sp<Allocation> A, sp<Allocation> X, int incX) {
498 // TRSV is the same as TRMV
499 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
500 int N = A->getType()->getY();
501 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
502 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
503 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
504
505 }
506
CTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)507 void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
508 sp<Allocation> A, sp<Allocation> X, int incX) {
509 // TRSV is the same as TRMV
510 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
511 int N = A->getType()->getY();
512 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
513 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
514 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
515
516 }
517
ZTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)518 void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
519 sp<Allocation> A, sp<Allocation> X, int incX) {
520 // TRSV is the same as TRMV
521 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
522 int N = A->getType()->getY();
523 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
524 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
525 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
526
527 }
528
STBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)529 void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
530 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
531 // TBSV is the same as TRMV + K >= 0
532 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
533 int N = A->getType()->getY();
534 if (K < 0) {
535 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
536 }
537 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
538 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
539 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
540 }
541
DTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)542 void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
543 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
544 // TBSV is the same as TRMV + K >= 0
545 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
546 int N = A->getType()->getY();
547 if (K < 0) {
548 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
549 }
550 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
551 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
552 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
553 }
554
CTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)555 void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
556 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
557 // TBSV is the same as TRMV + K >= 0
558 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
559 int N = A->getType()->getY();
560 if (K < 0) {
561 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
562 }
563 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
564 TransA, 0, 0, Uplo, Diag, 0, N, K,
565 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
566 }
567
ZTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)568 void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
569 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
570 // TBSV is the same as TRMV + K >= 0
571 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
572 int N = A->getType()->getY();
573 if (K < 0) {
574 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
575 }
576 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
577 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
578 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
579 }
580
STPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)581 void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
582 sp<Allocation> Ap, sp<Allocation> X, int incX) {
583 // TPSV is same as TPMV
584 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
585 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
586 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
587 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
588 }
589
DTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)590 void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
591 sp<Allocation> Ap, sp<Allocation> X, int incX) {
592 // TPSV is same as TPMV
593 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
594 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
595 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
596 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
597 }
598
CTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)599 void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
600 sp<Allocation> Ap, sp<Allocation> X, int incX) {
601 // TPSV is same as TPMV
602 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
603 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
604 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
605 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
606 }
607
ZTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)608 void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
609 sp<Allocation> Ap, sp<Allocation> X, int incX) {
610 // TPSV is same as TPMV
611 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
612 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
613 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
614 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
615 }
616
617 /**
618 * Level 2, S and D only
619 */
validateSYMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> A,sp<Allocation> X,sp<Allocation> Y,int incX,int incY)620 static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A,
621 sp<Allocation> X, sp<Allocation> Y, int incX, int incY) {
622 int N = A->getType()->getY();
623 if ((int)A->getType()->getX() != N) {
624 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
625 }
626 if (!A->getType()->getElement()->isCompatible(e) ||
627 !X->getType()->getElement()->isCompatible(e) ||
628 !Y->getType()->getElement()->isCompatible(e) ) {
629 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
630 }
631 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
632 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
633 }
634
635 if (incX <= 0 || incY <= 0) {
636 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
637 }
638 int expectedXDim = 1 + (N - 1) * incX;
639 if ((int)X->getType()->getX() != expectedXDim) {
640 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
641 }
642 int expectedYDim = 1 + (N - 1) * incY;
643 if ((int)Y->getType()->getX() != expectedYDim) {
644 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
645 }
646 return N;
647 }
validateSPMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> Ap,sp<Allocation> X,int incX,sp<Allocation> Y,int incY)648 static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap,
649 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
650 if (!Ap->getType()->getElement()->isCompatible(e) ||
651 !X->getType()->getElement()->isCompatible(e) ||
652 !Y->getType()->getElement()->isCompatible(e)) {
653 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
654 }
655 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
656 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
657 }
658
659 if (Ap->getType()->getY() > 1) {
660 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
661 }
662
663 int N = sqrt((double)Ap->getType()->getX() * 2);
664 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
665 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
666 }
667 if (incX <= 0 || incY <= 0) {
668 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
669 }
670 int expectedXDim = 1 + (N - 1) * incX;
671 if ((int)X->getType()->getX() != expectedXDim) {
672 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
673 }
674 int expectedYDim = 1 + (N - 1) * incY;
675 if ((int)Y->getType()->getX() != expectedYDim) {
676 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
677 }
678
679 return N;
680 }
validateGER(RS * mRS,sp<const Element> e,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)681 static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
682 sp<Allocation> Y, int incY, sp<Allocation> A) {
683 if (!A->getType()->getElement()->isCompatible(e) ||
684 !X->getType()->getElement()->isCompatible(e) ||
685 !Y->getType()->getElement()->isCompatible(e) ) {
686 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
687 }
688
689 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
690 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
691 }
692
693 int M = A->getType()->getY();
694 int N = A->getType()->getX();
695
696 if (N < 1 || M < 1) {
697 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
698 }
699 if (incX <= 0 || incY <= 0) {
700 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
701 }
702 int expectedXDim = 1 + (M - 1) * incX;
703 if ((int)X->getType()->getX() != expectedXDim) {
704 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
705 }
706 int expectedYDim = 1 + (N - 1) * incY;
707 if ((int)Y->getType()->getX() != expectedYDim) {
708 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
709 }
710
711
712 }
validateSYR(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> A)713 static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
714 sp<Allocation> X, int incX, sp<Allocation> A) {
715 if (!A->getType()->getElement()->isCompatible(e) ||
716 !X->getType()->getElement()->isCompatible(e)) {
717 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
718 }
719
720 int N = A->getType()->getX();
721
722 if (X->getType()->getY() > 1) {
723 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
724 }
725 if (N != (int)A->getType()->getY()) {
726 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
727 }
728 if (incX <= 0) {
729 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
730 }
731 int expectedXDim = 1 + (N - 1) * incX;
732 if ((int)X->getType()->getX() != expectedXDim) {
733 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
734 }
735 return N;
736 }
validateSPR(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Ap)737 static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
738 sp<Allocation> X, int incX, sp<Allocation> Ap) {
739 if (!Ap->getType()->getElement()->isCompatible(e) ||
740 !X->getType()->getElement()->isCompatible(e)) {
741 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
742 }
743 if (X->getType()->getY() > 1) {
744 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
745 }
746
747 if (Ap->getType()->getY() > 1) {
748 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
749 }
750
751 int N = sqrt((double)Ap->getType()->getX() * 2);
752 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
753 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
754 }
755 if (incX <= 0) {
756 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
757 }
758 int expectedXDim = 1 + (N - 1) * incX;
759 if ((int)X->getType()->getX() != expectedXDim) {
760 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
761 }
762
763 return N;
764 }
765
validateSYR2(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)766 static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
767 int incX, sp<Allocation> Y, int incY, sp<Allocation> A) {
768 if (!A->getType()->getElement()->isCompatible(e) ||
769 !X->getType()->getElement()->isCompatible(e) ||
770 !Y->getType()->getElement()->isCompatible(e)) {
771 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
772 }
773
774 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
775 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
776 }
777
778 int N = A->getType()->getX();
779
780 if (N != (int)A->getType()->getY()) {
781 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
782 }
783 if (incX <= 0 || incY <= 0) {
784 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
785 }
786 int expectedXDim = 1 + (N - 1) * incX;
787 int expectedYDim = 1 + (N - 1) * incY;
788 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
789 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
790 }
791 return N;
792
793 }
validateSPR2(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)794 static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
795 int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) {
796 if (!Ap->getType()->getElement()->isCompatible(e) ||
797 !X->getType()->getElement()->isCompatible(e) ||
798 !Y->getType()->getElement()->isCompatible(e)) {
799 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
800 }
801 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
802 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
803 }
804
805 if (Ap->getType()->getY() > 1) {
806 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
807 }
808
809 int N = sqrt((double)Ap->getType()->getX() * 2);
810 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
811 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
812 }
813 if (incX <= 0 || incY <= 0) {
814 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
815 }
816 int expectedXDim = 1 + (N - 1) * incX;
817 int expectedYDim = 1 + (N - 1) * incY;
818 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
819 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
820 }
821
822 return N;
823 }
824
SSYMV(RsBlasUplo Uplo,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)825 void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X,
826 int incX, float beta, sp<Allocation> Y, int incY) {
827 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
828 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
829 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
830 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
831 }
832
SSBMV(RsBlasUplo Uplo,int K,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)833 void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X,
834 int incX, float beta, sp<Allocation> Y, int incY) {
835 // SBMV is the same as SYMV + K >= 0
836 if (K < 0) {
837 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
838 }
839 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
840 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
841 0, 0, 0, Uplo, 0, 0, N, K, alpha,
842 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
843 }
844
SSPMV(RsBlasUplo Uplo,float alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)845 void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X,
846 int incX, float beta, sp<Allocation> Y, int incY) {
847 int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
848 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
849 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
850 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851 }
852
SGER(float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)853 void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX,
854 sp<Allocation> Y, int incY, sp<Allocation> A) {
855 int M = A->getType()->getY();
856 int N = A->getType()->getX();
857 validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
858 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
859 0, 0, 0, 0, 0, M, N, 0, alpha,
860 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
861 }
862
SSYR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> A)863 void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
864 int incX, sp<Allocation> A) {
865 int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
866 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
867 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
868 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
869 }
870
SSPR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)871 void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
872 int incX, sp<Allocation> Ap) {
873 int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
874 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
875 0, 0, 0, Uplo, 0, 0, N, 0,
876 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
877 }
878
SSYR2(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)879 void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
880 sp<Allocation> Y, int incY, sp<Allocation> A) {
881 int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
882 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
883 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
884 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
885 }
886
SSPR2(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)887 void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
888 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
889 int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
890 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
891 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
893 }
894
DSYMV(RsBlasUplo Uplo,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)895 void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X,
896 int incX, double beta, sp<Allocation> Y, int incY) {
897 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
898 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
899 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
901 }
902
DSBMV(RsBlasUplo Uplo,int K,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)903 void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X,
904 int incX, double beta, sp<Allocation> Y, int incY) {
905 // SBMV is the same as SYMV + K >= 0
906 if (K < 0) {
907 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
908 }
909 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
910 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
911 0, 0, 0, Uplo, 0, 0, N, K, alpha,
912 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
913 }
914
DSPMV(RsBlasUplo Uplo,double alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)915 void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X,
916 int incX, double beta, sp<Allocation> Y, int incY) {
917 int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
918 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
919 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
920 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921 }
922
DGER(double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)923 void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y,
924 int incY, sp<Allocation> A) {
925 int M = A->getType()->getY();
926 int N = A->getType()->getX();
927 validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
928 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
929 0, 0, 0, 0, 0, M, N, 0, alpha,
930 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
931 }
932
DSYR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> A)933 void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
934 int incX, sp<Allocation> A) {
935 int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
936 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
937 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
938 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
939 }
940
DSPR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)941 void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
942 int incX, sp<Allocation> Ap) {
943 int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
944 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
945 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
947 }
948
DSYR2(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)949 void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
950 sp<Allocation> Y, int incY, sp<Allocation> A) {
951 int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
952 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
953 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
955 }
956
DSPR2(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)957 void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
958 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
959 int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
960 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
961 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
963 }
964
965
966 /**
967 * Level 2, C and Z only
968 */
969
validateGERU(RS * mRS,sp<const Element> e,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)970 static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
971 sp<Allocation> Y, int incY, sp<Allocation> A) {
972 if (!A->getType()->getElement()->isCompatible(e) ||
973 !X->getType()->getElement()->isCompatible(e) ||
974 !Y->getType()->getElement()->isCompatible(e)) {
975 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
976 }
977 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
978 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
979 }
980
981 int M = A->getType()->getY();
982 int N = A->getType()->getX();
983 if (incX <= 0 || incY <= 0) {
984 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
985 }
986 int expectedXDim = 1 + (M - 1) * incX;
987 if ((int)X->getType()->getX() != expectedXDim) {
988 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
989 }
990 int expectedYDim = 1 + (N - 1) * incY;
991 if ((int)Y->getType()->getX() != expectedYDim) {
992 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
993 }
994
995 }
996
CHEMV(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)997 void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A,
998 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
999 // HEMV is the same as SYR2 validation-wise
1000 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1001 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1002 0, 0, 0, Uplo, 0, 0, N, 0,
1003 alpha.x, alpha.y, A->getID(), X->getID(),
1004 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1005 }
1006
CHBMV(RsBlasUplo Uplo,int K,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)1007 void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A,
1008 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1009 // HBMV is the same as SYR2 validation-wise
1010 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1011 if (K < 0) {
1012 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1013 }
1014 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1015 0, 0, 0, Uplo, 0, 0, N, K,
1016 alpha.x, alpha.y, A->getID(), X->getID(),
1017 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1018 }
1019
CHPMV(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)1020 void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap,
1021 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1022 // HPMV is the same as SPR2
1023 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1024 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1025 0, 0, 0, Uplo, 0, 0, N, 0,
1026 alpha.x, alpha.y, Ap->getID(), X->getID(),
1027 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1028 }
1029
CGERU(Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1030 void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX,
1031 sp<Allocation> Y, int incY, sp<Allocation> A) {
1032 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1033 int M = A->getType()->getY();
1034 int N = A->getType()->getX();
1035 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1036 0, 0, 0, 0, 0, M, N, 0,
1037 alpha.x, alpha.y, X->getID(), Y->getID(),
1038 0, 0, A->getID(), incX, incY, 0, 0);
1039 }
1040
CGERC(Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1041 void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX,
1042 sp<Allocation> Y, int incY, sp<Allocation> A) {
1043 // Same as GERU
1044 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1045 int M = A->getType()->getY();
1046 int N = A->getType()->getX();
1047 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1048 0, 0, 0, 0, 0, M, N, 0,
1049 alpha.x, alpha.y, X->getID(), Y->getID(),
1050 0, 0, A->getID(), incX, incY, 0, 0);
1051 }
1052
CHER(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> A)1053 void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1054 int incX, sp<Allocation> A) {
1055 // Same as SYR
1056 int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1057 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1058 0, 0, 0, Uplo, 0, 0, N, 0,
1059 alpha, 0, X->getID(), 0,
1060 0, 0, A->getID(), incX, 0, 0, 0);
1061 }
1062
CHPR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)1063 void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1064 int incX, sp<Allocation> Ap) {
1065 // Equivalent to SPR for validation
1066 int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1067 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1068 0, 0, 0, Uplo, 0, 0, N, 0,
1069 alpha, 0, X->getID(), 0,
1070 0, 0, Ap->getID(), incX, 0, 0, 0);
1071 }
1072
CHER2(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1073 void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1074 sp<Allocation> Y, int incY, sp<Allocation> A) {
1075 // Same as SYR2
1076 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1077 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1078 0, 0, 0, Uplo, 0, 0, N, 0,
1079 alpha.x, alpha.y, X->getID(), Y->getID(),
1080 0, 0, A->getID(), incX, incY, 0, 0);
1081 }
1082
CHPR2(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)1083 void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1084 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1085 // Same as SPR2
1086 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1087 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1088 0, 0, 0, Uplo, 0, 0, N, 0,
1089 alpha.x, alpha.y, X->getID(), Y->getID(),
1090 0, 0, Ap->getID(), incX, incY, 0, 0);
1091 }
1092
ZHEMV(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1093 void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A,
1094 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
1095 // HEMV is the same as SYR2 validation-wise
1096 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1097 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1098 0, 0, 0, Uplo, 0, 0, N, 0,
1099 alpha.x, alpha.y, A->getID(), X->getID(),
1100 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1101 }
1102
ZHBMV(RsBlasUplo Uplo,int K,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1103 void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
1104 int incX, Double2 beta, sp<Allocation> Y, int incY) {
1105 // HBMV is the same as SYR2 validation-wise
1106 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1107 if (K < 0) {
1108 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1109 }
1110 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1111 0, 0, 0, Uplo, 0, 0, N, K,
1112 alpha.x, alpha.y, A->getID(), X->getID(),
1113 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1114 }
1115
ZHPMV(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1116 void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X,
1117 int incX, Double2 beta, sp<Allocation> Y, int incY) {
1118 // HPMV is the same as SPR2
1119 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1120 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1121 0, 0, 0, Uplo, 0, 0, N, 0,
1122 alpha.x, alpha.y, Ap->getID(), X->getID(),
1123 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1124 }
1125
ZGERU(Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1126 void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX,
1127 sp<Allocation> Y, int incY, sp<Allocation> A) {
1128 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1129 int M = A->getType()->getY();
1130 int N = A->getType()->getX();
1131 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1132 0, 0, 0, 0, 0, M, N, 0,
1133 alpha.x, alpha.y, X->getID(), Y->getID(),
1134 0, 0, A->getID(), incX, incY, 0, 0);
1135 }
1136
ZGERC(Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1137 void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX,
1138 sp<Allocation> Y, int incY, sp<Allocation> A) {
1139 // Same as GERU
1140 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1141 int M = A->getType()->getY();
1142 int N = A->getType()->getX();
1143 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1144 0, 0, 0, 0, 0, M, N, 0,
1145 alpha.x, alpha.y, X->getID(), Y->getID(),
1146 0, 0, A->getID(), incX, incY, 0, 0);
1147 }
1148
ZHER(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> A)1149 void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1150 int incX, sp<Allocation> A) {
1151 // Same as SYR
1152 int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1153 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1154 0, 0, 0, Uplo, 0, 0, N, 0,
1155 alpha, 0, X->getID(), 0,
1156 0, 0, A->getID(), incX, 0, 0, 0);
1157 }
1158
ZHPR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)1159 void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1160 int incX, sp<Allocation> Ap) {
1161 // Equivalent to SPR for validation
1162 int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1163 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1164 0, 0, 0, Uplo, 0, 0, N, 0,
1165 alpha, 0, X->getID(), 0,
1166 0, 0, Ap->getID(), incX, 0, 0, 0);
1167 }
1168
ZHER2(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1169 void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1170 sp<Allocation> Y, int incY, sp<Allocation> A) {
1171 // Same as SYR2
1172 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1173 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1174 0, 0, 0, Uplo, 0, 0, N, 0,
1175 alpha.x, alpha.y, X->getID(), Y->getID(),
1176 0, 0, A->getID(), incX, incY, 0, 0);
1177 }
1178
ZHPR2(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)1179 void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1180 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1181 // Same as SPR2
1182 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1183 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1184 0, 0, 0, Uplo, 0, 0, N, 0,
1185 alpha.x, alpha.y, X->getID(), Y->getID(),
1186 0, 0, Ap->getID(), incX, incY, 0, 0);
1187 }
1188
1189
1190 /**
1191 * Level 3 BLAS
1192 */
1193
validateL3(RS * mRS,sp<const Element> e,int TransA,int TransB,int Side,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1194 static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side,
1195 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1196 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1197 if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1198 (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1199 (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1200 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1201 }
1202 if (C == nullptr) {
1203 // Since matrix C is used to store the result, it cannot be null.
1204 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1205 }
1206 cM = C->getType()->getY();
1207 cN = C->getType()->getX();
1208
1209 if (Side == RsBlasRight) {
1210 if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1211 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1212 }
1213 if (B != nullptr) {
1214 bM = A->getType()->getY();
1215 bN = A->getType()->getX();
1216 }
1217 if (A != nullptr) {
1218 aM = B->getType()->getY();
1219 aN = B->getType()->getX();
1220 }
1221 } else {
1222 if (A != nullptr) {
1223 if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1224 aN = A->getType()->getY();
1225 aM = A->getType()->getX();
1226 } else {
1227 aM = A->getType()->getY();
1228 aN = A->getType()->getX();
1229 }
1230 }
1231 if (B != nullptr) {
1232 if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1233 bN = B->getType()->getY();
1234 bM = B->getType()->getX();
1235 } else {
1236 bM = B->getType()->getY();
1237 bN = B->getType()->getX();
1238 }
1239 }
1240 }
1241 if (A != nullptr && B != nullptr && C != nullptr) {
1242 if (aN != bM || aM != cM || bN != cN) {
1243 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1244 }
1245 } else if (A != nullptr && C != nullptr) {
1246 // A and C only, for SYRK
1247 if (cM != cN) {
1248 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1249 }
1250 if (aM != cM) {
1251 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252 }
1253 } else if (A != nullptr && B != nullptr) {
1254 // A and B only
1255 if (aN != bM) {
1256 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1257 }
1258 }
1259
1260 }
1261
SGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1262 void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1263 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1264 validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1265
1266 int M = -1, N = -1, K = -1;
1267 if (TransA != RsBlasNoTrans) {
1268 M = A->getType()->getX();
1269 K = A->getType()->getY();
1270 } else {
1271 M = A->getType()->getY();
1272 K = A->getType()->getX();
1273 }
1274 if (TransB != RsBlasNoTrans) {
1275 N = B->getType()->getY();
1276 } else {
1277 N = B->getType()->getX();
1278 }
1279 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1280 TransA, TransB, 0, 0, 0, M, N, K,
1281 alpha, A->getID(), B->getID(),
1282 beta, C->getID(), 0, 0, 0, 0);
1283 }
1284
DGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1285 void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1286 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1287 validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1288 int M = -1, N = -1, K = -1;
1289 if (TransA != RsBlasNoTrans) {
1290 M = A->getType()->getX();
1291 K = A->getType()->getY();
1292 } else {
1293 M = A->getType()->getY();
1294 K = A->getType()->getX();
1295 }
1296 if (TransB != RsBlasNoTrans) {
1297 N = B->getType()->getY();
1298 } else {
1299 N = B->getType()->getX();
1300 }
1301 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1302 TransA, TransB, 0, 0, 0, M, N, K,
1303 alpha, A->getID(), B->getID(),
1304 beta, C->getID(), 0, 0, 0, 0);
1305 }
1306
CGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1307 void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1308 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1309 validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1310 int M = -1, N = -1, K = -1;
1311 if (TransA != RsBlasNoTrans) {
1312 M = A->getType()->getX();
1313 K = A->getType()->getY();
1314 } else {
1315 M = A->getType()->getY();
1316 K = A->getType()->getX();
1317 }
1318 if (TransB != RsBlasNoTrans) {
1319 N = B->getType()->getY();
1320 } else {
1321 N = B->getType()->getX();
1322 }
1323 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1324 TransA, TransB, 0, 0, 0, M, N, K,
1325 alpha.x, alpha.y, A->getID(), B->getID(),
1326 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1327 }
1328
ZGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1329 void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1330 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1331 validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1332 int M = -1, N = -1, K = -1;
1333 if (TransA != RsBlasNoTrans) {
1334 M = A->getType()->getX();
1335 K = A->getType()->getY();
1336 } else {
1337 M = A->getType()->getY();
1338 K = A->getType()->getX();
1339 }
1340 if (TransB != RsBlasNoTrans) {
1341 N = B->getType()->getY();
1342 } else {
1343 N = B->getType()->getX();
1344 }
1345 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1346 TransA, TransB, 0, 0, 0, M, N, K,
1347 alpha.x, alpha.y, A->getID(), B->getID(),
1348 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1349 }
1350
SSYMM(RsBlasSide Side,RsBlasUplo Uplo,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1351 void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1352 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1353 //For SYMM, Matrix A should be symmetric
1354 if (A->getType()->getX() != A->getType()->getY()) {
1355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1356 }
1357 validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1358 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1359 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1360 alpha, A->getID(), B->getID(),
1361 beta, C->getID(), 0, 0, 0, 0);
1362 }
1363
DSYMM(RsBlasSide Side,RsBlasUplo Uplo,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1364 void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1365 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1366 if (A->getType()->getX() != A->getType()->getY()) {
1367 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1368 }
1369 validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1370 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1371 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1372 alpha, A->getID(), B->getID(),
1373 beta, C->getID(), 0, 0, 0, 0);
1374 }
1375
CSYMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1376 void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1377 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1378 if (A->getType()->getX() != A->getType()->getY()) {
1379 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1380 }
1381 validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1382 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1383 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1384 alpha.x, alpha.y, A->getID(), B->getID(),
1385 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1386 }
1387
ZSYMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1388 void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1389 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1390 if (A->getType()->getX() != A->getType()->getY()) {
1391 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1392 }
1393 validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1394 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1395 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1396 alpha.x, alpha.y, A->getID(), B->getID(),
1397 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1398 }
1399
SSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,float beta,sp<Allocation> C)1400 void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1401 sp<Allocation> A, float beta, sp<Allocation> C) {
1402 validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1403 int K = -1;
1404 if (Trans != RsBlasNoTrans) {
1405 K = A->getType()->getY();
1406 } else {
1407 K = A->getType()->getX();
1408 }
1409 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1410 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1411 alpha, A->getID(), 0,
1412 beta, C->getID(), 0, 0, 0, 0);
1413 }
1414
DSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,double beta,sp<Allocation> C)1415 void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1416 sp<Allocation> A, double beta, sp<Allocation> C) {
1417 validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1418 int K = -1;
1419 if (Trans != RsBlasNoTrans) {
1420 K = A->getType()->getY();
1421 } else {
1422 K = A->getType()->getX();
1423 }
1424 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1425 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1426 alpha, A->getID(), 0,
1427 beta, C->getID(), 0, 0, 0, 0);
1428 }
1429
CSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,Float2 beta,sp<Allocation> C)1430 void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1431 sp<Allocation> A, Float2 beta, sp<Allocation> C) {
1432 validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1433 int K = -1;
1434 if (Trans != RsBlasNoTrans) {
1435 K = A->getType()->getY();
1436 } else {
1437 K = A->getType()->getX();
1438 }
1439 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1440 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1441 alpha.x, alpha.y, A->getID(), 0,
1442 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1443 }
1444
ZSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,Double2 beta,sp<Allocation> C)1445 void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1446 sp<Allocation> A, Double2 beta, sp<Allocation> C) {
1447 validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1448 int K = -1;
1449 if (Trans != RsBlasNoTrans) {
1450 K = A->getType()->getY();
1451 } else {
1452 K = A->getType()->getX();
1453 }
1454 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1455 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1456 alpha.x, alpha.y, A->getID(), 0,
1457 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1458 }
1459
validateSYR2K(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1460 static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1461 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1462 if (!A->getType()->getElement()->isCompatible(e) ||
1463 !B->getType()->getElement()->isCompatible(e) ||
1464 !C->getType()->getElement()->isCompatible(e)) {
1465 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1466 }
1467 int Cdim = -1;
1468 // A is n x k if no transpose, k x n if transpose
1469 // C is n x n
1470 if (Trans == RsBlasTrans) {
1471 // check columns versus C
1472 Cdim = A->getType()->getX();
1473 } else {
1474 // check rows versus C
1475 Cdim = A->getType()->getY();
1476 }
1477 if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1478 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1479 }
1480 // A dims == B dims
1481 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1482 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1483 }
1484 }
1485
SSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1486 void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1487 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1488 validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1489 int K = -1;
1490 if (Trans != RsBlasNoTrans) {
1491 K = A->getType()->getY();
1492 } else {
1493 K = A->getType()->getX();
1494 }
1495 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1496 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1497 alpha, A->getID(), B->getID(),
1498 beta, C->getID(), 0, 0, 0, 0);
1499 }
1500
DSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1501 void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1502 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1503 validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1504 int K = -1;
1505 if (Trans != RsBlasNoTrans) {
1506 K = A->getType()->getY();
1507 } else {
1508 K = A->getType()->getX();
1509 }
1510 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1511 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1512 alpha, A->getID(), B->getID(),
1513 beta, C->getID(), 0, 0, 0, 0);
1514 }
1515
CSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1516 void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1517 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1518 validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1519 int K = -1;
1520 if (Trans != RsBlasNoTrans) {
1521 K = A->getType()->getY();
1522 } else {
1523 K = A->getType()->getX();
1524 }
1525 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1526 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1527 alpha.x, alpha.y, A->getID(), B->getID(),
1528 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1529 }
1530
ZSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1531 void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1532 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1533 validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1534 int K = -1;
1535 if (Trans != RsBlasNoTrans) {
1536 K = A->getType()->getY();
1537 } else {
1538 K = A->getType()->getX();
1539 }
1540 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1541 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1542 alpha.x, alpha.y, A->getID(), B->getID(),
1543 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1544 }
1545
validateTRMM(RS * mRS,sp<const Element> e,RsBlasSide Side,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> B)1546 static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1547 sp<Allocation> A, sp<Allocation> B) {
1548 int aM = -1, aN = -1, bM = -1, bN = -1;
1549 if (!A->getType()->getElement()->isCompatible(e) ||
1550 !B->getType()->getElement()->isCompatible(e)) {
1551 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1552 }
1553
1554 aM = A->getType()->getY();
1555 aN = A->getType()->getX();
1556 if (aM != aN) {
1557 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1558 }
1559
1560 bM = B->getType()->getY();
1561 bN = B->getType()->getX();
1562 if (Side == RsBlasLeft) {
1563 if (aN != bM) {
1564 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1565 }
1566 } else {
1567 if (bN != aM) {
1568 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1569 }
1570 }
1571 }
1572
STRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,sp<Allocation> A,sp<Allocation> B)1573 void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1574 float alpha, sp<Allocation> A, sp<Allocation> B) {
1575 validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1576 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1577 TransA, 0, Side, Uplo, Diag,\
1578 B->getType()->getY(), B->getType()->getX(), 0,
1579 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1580 }
1581
DTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,sp<Allocation> A,sp<Allocation> B)1582 void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1583 double alpha, sp<Allocation> A, sp<Allocation> B) {
1584 validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1585 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1586 TransA, 0, Side, Uplo, Diag,
1587 B->getType()->getY(), B->getType()->getX(), 0,
1588 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1589 }
1590
CTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,sp<Allocation> A,sp<Allocation> B)1591 void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1592 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1593 validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1594 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1595 TransA, 0, Side, Uplo, Diag,
1596 B->getType()->getY(), B->getType()->getX(), 0,
1597 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1598 }
1599
ZTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,sp<Allocation> A,sp<Allocation> B)1600 void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1601 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1602 validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1603 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1604 TransA, 0, Side, Uplo, Diag,
1605 B->getType()->getY(), B->getType()->getX(), 0,
1606 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1607 }
1608
validateTRSM(RS * mRS,sp<const Element> e,RsBlasSide Side,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> B)1609 static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1610 sp<Allocation> A, sp<Allocation> B) {
1611 int adim = -1, bM = -1, bN = -1;
1612 if (!A->getType()->getElement()->isCompatible(e) ||
1613 !B->getType()->getElement()->isCompatible(e)) {
1614 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1615 }
1616 adim = A->getType()->getX();
1617 if (adim != (int)A->getType()->getY()) {
1618 // This may be unnecessary, the restriction could potentially be relaxed.
1619 // Allocation A needs to contain at least that symmetric matrix but could theoretically
1620 // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1621 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1622 }
1623 bM = B->getType()->getY();
1624 bN = B->getType()->getX();
1625 if (Side == RsBlasLeft) {
1626 // A is M*M
1627 if (adim != bM) {
1628 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1629 }
1630 } else {
1631 // A is N*N
1632 if (adim != bN) {
1633 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1634 }
1635 }
1636 }
1637
STRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,sp<Allocation> A,sp<Allocation> B)1638 void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1639 float alpha, sp<Allocation> A, sp<Allocation> B) {
1640 validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1641 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1642 TransA, 0, Side, Uplo, Diag,
1643 B->getType()->getY(), B->getType()->getX(), 0,
1644 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1645 }
1646
DTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,sp<Allocation> A,sp<Allocation> B)1647 void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1648 double alpha, sp<Allocation> A, sp<Allocation> B) {
1649 validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1650 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1651 TransA, 0, Side, Uplo, Diag,
1652 B->getType()->getY(), B->getType()->getX(), 0,
1653 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1654 }
1655
CTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,sp<Allocation> A,sp<Allocation> B)1656 void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1657 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1658 validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1659 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1660 TransA, 0, Side, Uplo, Diag,
1661 B->getType()->getY(), B->getType()->getX(), 0,
1662 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1663 }
1664
ZTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,sp<Allocation> A,sp<Allocation> B)1665 void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1666 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1667 validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1668 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1669 TransA, 0, Side, Uplo, Diag,
1670 B->getType()->getY(), B->getType()->getX(), 0,
1671 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1672 }
1673
validateHEMM(RS * mRS,sp<const Element> e,RsBlasSide Side,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1674 static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side,
1675 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1676 if (!A->getType()->getElement()->isCompatible(e) ||
1677 !B->getType()->getElement()->isCompatible(e) ||
1678 !C->getType()->getElement()->isCompatible(e)) {
1679 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1680 }
1681
1682 // A must be square; can potentially be relaxed similar to TRSM
1683 int adim = A->getType()->getX();
1684 if (adim != (int)A->getType()->getY()) {
1685 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1686 }
1687 if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1688 (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1689 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1690 }
1691 if (B->getType()->getX() != C->getType()->getX() ||
1692 B->getType()->getY() != C->getType()->getY()) {
1693 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1694 }
1695 }
1696
CHEMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1697 void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1698 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1699 validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1700 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1701 0, 0, Side, Uplo, 0,
1702 C->getType()->getY(), C->getType()->getX(), 0,
1703 alpha.x, alpha.y, A->getID(), B->getID(),
1704 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1705 }
1706
ZHEMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1707 void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1708 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1709 validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1710 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1711 0, 0, Side, Uplo, 0,
1712 C->getType()->getY(), C->getType()->getX(), 0,
1713 alpha.x, alpha.y, A->getID(), B->getID(),
1714 beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1715 }
1716
validateHERK(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> C)1717 static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1718 sp<Allocation> A, sp<Allocation> C) {
1719 if (!A->getType()->getElement()->isCompatible(e) ||
1720 !C->getType()->getElement()->isCompatible(e)) {
1721 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1722 }
1723 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1724 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1725 }
1726 int cdim = C->getType()->getX();
1727 if (cdim != (int)C->getType()->getY()) {
1728 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1729 }
1730 if (Trans == RsBlasNoTrans) {
1731 if (cdim != (int)A->getType()->getY()) {
1732 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1733 }
1734 } else {
1735 if (cdim != (int)A->getType()->getX()) {
1736 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1737 }
1738 }
1739 }
1740
CHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,float beta,sp<Allocation> C)1741 void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1742 sp<Allocation> A, float beta, sp<Allocation> C) {
1743 validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1744 int k = 0;
1745 if (Trans == RsBlasConjTrans) {
1746 k = A->getType()->getY();
1747 } else {
1748 k = A->getType()->getX();
1749 }
1750 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1751 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1752 alpha, 0, A->getID(), 0,
1753 beta, 0, C->getID(), 0, 0, 0, 0);
1754 }
1755
ZHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,double beta,sp<Allocation> C)1756 void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1757 sp<Allocation> A, double beta, sp<Allocation> C) {
1758 validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1759 int k = 0;
1760 if (Trans == RsBlasConjTrans) {
1761 k = A->getType()->getY();
1762 } else {
1763 k = A->getType()->getX();
1764 }
1765 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1766 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1767 alpha, 0, A->getID(), 0,
1768 beta, 0, C->getID(), 0, 0, 0, 0);
1769 }
1770
validateHER2K(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1771 static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1772 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1773 if (!A->getType()->getElement()->isCompatible(e) ||
1774 !B->getType()->getElement()->isCompatible(e) ||
1775 !C->getType()->getElement()->isCompatible(e)) {
1776 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1777 }
1778 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1779 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1780 }
1781 int cdim = C->getType()->getX();
1782 if (cdim != (int)C->getType()->getY()) {
1783 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1784 }
1785 if (Trans == RsBlasNoTrans) {
1786 if ((int)A->getType()->getY() != cdim) {
1787 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1788 }
1789 } else {
1790 if ((int)A->getType()->getX() != cdim) {
1791 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1792 }
1793 }
1794 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1795 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1796 }
1797 }
1798
CHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1799 void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1800 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1801 validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1802 int k = 0;
1803 if (Trans == RsBlasNoTrans) {
1804 k = A->getType()->getX();
1805 } else {
1806 k = A->getType()->getY();
1807 }
1808 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1809 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1810 alpha.x, alpha.y, A->getID(), B->getID(),
1811 beta, 0, C->getID(), 0, 0, 0, 0);
1812 }
1813
ZHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1814 void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1815 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1816 validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1817 int k = 0;
1818 if (Trans == RsBlasNoTrans) {
1819 k = A->getType()->getX();
1820 } else {
1821 k = A->getType()->getY();
1822 }
1823 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1824 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1825 alpha.x, alpha.y, A->getID(), B->getID(),
1826 beta, 0, C->getID(), 0, 0, 0, 0);
1827 }
1828
1829
1830
BNNM(sp<Allocation> A,int a_offset,sp<Allocation> B,int b_offset,sp<Allocation> C,int c_offset,int c_mult)1831 void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset,
1832 sp<Allocation> C, int c_offset, int c_mult) {
1833 validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1834
1835 if (a_offset < 0 || a_offset > 255) {
1836 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1837 }
1838 if (b_offset < 0 || b_offset > 255) {
1839 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1840 }
1841 int M = -1, N = -1, K = -1;
1842 M = A->getType()->getY();
1843 N = B->getType()->getY();
1844 K = A->getType()->getX();
1845
1846 nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1847 B->getID(), b_offset, C->getID(), c_offset, c_mult);
1848 }
1849