1 //===--- opencl_example.cpp - Example of using Acxxel with OpenCL ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// This file is an example of using OpenCL with Acxxel.
10 ///
11 //===----------------------------------------------------------------------===//
12 
#include "acxxel.h"

#include <array>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
18 
/// OpenCL C source for the saxpy kernel. Each work-item with index I < N
/// computes X[I] = A * X[I] + Y[I], overwriting X in place.
///
/// Declared as a constexpr character array rather than a mutable
/// `const char *`, so the name cannot be re-pointed at runtime and the
/// length is a compile-time constant (sizeof - 1). It still decays to a
/// `const char *` wherever a pointer is expected (e.g. std::strlen).
static constexpr char SaxpyKernelSource[] = R"(
__kernel void saxpyKernel(float A, __global float *X, __global float *Y, int N) {
  int I = get_global_id(0);
  if (I < N)
    X[I] = A * X[I] + Y[I];
}
)";
26 
27 template <size_t N>
saxpy(float A,std::array<float,N> & X,const std::array<float,N> & Y)28 void saxpy(float A, std::array<float, N> &X, const std::array<float, N> &Y) {
29   acxxel::Platform *OpenCL = acxxel::getOpenCLPlatform().getValue();
30   acxxel::Stream Stream = OpenCL->createStream().takeValue();
31   auto DeviceX = OpenCL->mallocD<float>(N).takeValue();
32   auto DeviceY = OpenCL->mallocD<float>(N).takeValue();
33   Stream.syncCopyHToD(X, DeviceX).syncCopyHToD(Y, DeviceY);
34   acxxel::Program Program =
35       OpenCL
36           ->createProgramFromSource(acxxel::Span<const char>(
37               SaxpyKernelSource, std::strlen(SaxpyKernelSource)))
38           .takeValue();
39   acxxel::Kernel Kernel = Program.createKernel("saxpyKernel").takeValue();
40   float *RawX = static_cast<float *>(DeviceX);
41   float *RawY = static_cast<float *>(DeviceY);
42   int IntLength = N;
43   void *Arguments[] = {&A, &RawX, &RawY, &IntLength};
44   size_t ArgumentSizes[] = {sizeof(float), sizeof(float *), sizeof(float *),
45                             sizeof(int)};
46   acxxel::Status Status =
47       Stream.asyncKernelLaunch(Kernel, N, Arguments, ArgumentSizes)
48           .syncCopyDToH(DeviceX, X)
49           .sync();
50   if (Status.isError()) {
51     std::fprintf(stderr, "Error during saxpy: %s\n",
52                  Status.getMessage().c_str());
53     std::exit(EXIT_FAILURE);
54   }
55 }
56 
main()57 int main() {
58   float A = 2.f;
59   std::array<float, 3> X{{0.f, 1.f, 2.f}};
60   std::array<float, 3> Y{{3.f, 4.f, 5.f}};
61   std::array<float, 3> Expected{{3.f, 6.f, 9.f}};
62   saxpy(A, X, Y);
63   for (int I = 0; I < 3; ++I)
64     if (X[I] != Expected[I]) {
65       std::fprintf(stderr, "Mismatch at position %d, %f != %f\n", I, X[I],
66                    Expected[I]);
67       std::exit(EXIT_FAILURE);
68     }
69 }
70