1 //===--- opencl_example.cpp - Example of using Acxxel with OpenCL ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// This file is an example of using OpenCL with Acxxel.
10 ///
11 //===----------------------------------------------------------------------===//
12
#include "acxxel.h"

#include <array>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <utility>
18
/// OpenCL C source for the saxpy kernel.
///
/// Each work-item takes its index from get_global_id(0). For indices below N it
/// computes X[I] = A * X[I] + Y[I], writing the result back into X in place.
/// Both the pointer and the text are const, so the kernel source cannot be
/// reassigned by accident.
static const char *const SaxpyKernelSource = R"(
__kernel void saxpyKernel(float A, __global float *X, __global float *Y, int N) {
  int I = get_global_id(0);
  if (I < N)
    X[I] = A * X[I] + Y[I];
}
)";
26
27 template <size_t N>
saxpy(float A,std::array<float,N> & X,const std::array<float,N> & Y)28 void saxpy(float A, std::array<float, N> &X, const std::array<float, N> &Y) {
29 acxxel::Platform *OpenCL = acxxel::getOpenCLPlatform().getValue();
30 acxxel::Stream Stream = OpenCL->createStream().takeValue();
31 auto DeviceX = OpenCL->mallocD<float>(N).takeValue();
32 auto DeviceY = OpenCL->mallocD<float>(N).takeValue();
33 Stream.syncCopyHToD(X, DeviceX).syncCopyHToD(Y, DeviceY);
34 acxxel::Program Program =
35 OpenCL
36 ->createProgramFromSource(acxxel::Span<const char>(
37 SaxpyKernelSource, std::strlen(SaxpyKernelSource)))
38 .takeValue();
39 acxxel::Kernel Kernel = Program.createKernel("saxpyKernel").takeValue();
40 float *RawX = static_cast<float *>(DeviceX);
41 float *RawY = static_cast<float *>(DeviceY);
42 int IntLength = N;
43 void *Arguments[] = {&A, &RawX, &RawY, &IntLength};
44 size_t ArgumentSizes[] = {sizeof(float), sizeof(float *), sizeof(float *),
45 sizeof(int)};
46 acxxel::Status Status =
47 Stream.asyncKernelLaunch(Kernel, N, Arguments, ArgumentSizes)
48 .syncCopyDToH(DeviceX, X)
49 .sync();
50 if (Status.isError()) {
51 std::fprintf(stderr, "Error during saxpy: %s\n",
52 Status.getMessage().c_str());
53 std::exit(EXIT_FAILURE);
54 }
55 }
56
main()57 int main() {
58 float A = 2.f;
59 std::array<float, 3> X{{0.f, 1.f, 2.f}};
60 std::array<float, 3> Y{{3.f, 4.f, 5.f}};
61 std::array<float, 3> Expected{{3.f, 6.f, 9.f}};
62 saxpy(A, X, Y);
63 for (int I = 0; I < 3; ++I)
64 if (X[I] != Expected[I]) {
65 std::fprintf(stderr, "Mismatch at position %d, %f != %f\n", I, X[I],
66 Expected[I]);
67 std::exit(EXIT_FAILURE);
68 }
69 }
70