1 /*===--------------------------------------------------------------------------
2  *              ATMI (Asynchronous Task and Memory Interface)
3  *
4  * This file is distributed under the MIT License. See LICENSE.txt for details.
5  *===------------------------------------------------------------------------*/
6 #include "atmi_interop_hsa.h"
7 #include "internal.h"
8 
9 using core::atl_is_atmi_initialized;
10 
atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,const char * symbol,void ** var_addr,unsigned int * var_size)11 atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
12                                                const char *symbol,
13                                                void **var_addr,
14                                                unsigned int *var_size) {
15   /*
16      // Typical usage:
17      void *var_addr;
18      size_t var_size;
19      atmi_interop_hsa_get_symbol_addr(gpu_place, "symbol_name", &var_addr,
20      &var_size);
21      atmi_memcpy(signal, host_add, var_addr, var_size);
22   */
23 
24   if (!atl_is_atmi_initialized())
25     return ATMI_STATUS_ERROR;
26   atmi_machine_t *machine = atmi_machine_get_info();
27   if (!symbol || !var_addr || !var_size || !machine)
28     return ATMI_STATUS_ERROR;
29   if (place.dev_id < 0 ||
30       place.dev_id >= machine->device_count_by_type[place.dev_type])
31     return ATMI_STATUS_ERROR;
32 
33   // get the symbol info
34   std::string symbolStr = std::string(symbol);
35   if (SymbolInfoTable[place.dev_id].find(symbolStr) !=
36       SymbolInfoTable[place.dev_id].end()) {
37     atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr];
38     *var_addr = reinterpret_cast<void *>(info.addr);
39     *var_size = info.size;
40     return ATMI_STATUS_SUCCESS;
41   } else {
42     *var_addr = NULL;
43     *var_size = 0;
44     return ATMI_STATUS_ERROR;
45   }
46 }
47 
atmi_interop_hsa_get_kernel_info(atmi_mem_place_t place,const char * kernel_name,hsa_executable_symbol_info_t kernel_info,uint32_t * value)48 atmi_status_t atmi_interop_hsa_get_kernel_info(
49     atmi_mem_place_t place, const char *kernel_name,
50     hsa_executable_symbol_info_t kernel_info, uint32_t *value) {
51   /*
52      // Typical usage:
53      uint32_t value;
54      atmi_interop_hsa_get_kernel_addr(gpu_place, "kernel_name",
55                                   HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
56                                   &val);
57   */
58 
59   if (!atl_is_atmi_initialized())
60     return ATMI_STATUS_ERROR;
61   atmi_machine_t *machine = atmi_machine_get_info();
62   if (!kernel_name || !value || !machine)
63     return ATMI_STATUS_ERROR;
64   if (place.dev_id < 0 ||
65       place.dev_id >= machine->device_count_by_type[place.dev_type])
66     return ATMI_STATUS_ERROR;
67 
68   atmi_status_t status = ATMI_STATUS_SUCCESS;
69   // get the kernel info
70   std::string kernelStr = std::string(kernel_name);
71   if (KernelInfoTable[place.dev_id].find(kernelStr) !=
72       KernelInfoTable[place.dev_id].end()) {
73     atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr];
74     switch (kernel_info) {
75     case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE:
76       *value = info.group_segment_size;
77       break;
78     case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE:
79       *value = info.private_segment_size;
80       break;
81     case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE:
82       // return the size for non-implicit args
83       *value = info.kernel_segment_size - sizeof(atmi_implicit_args_t);
84       break;
85     default:
86       *value = 0;
87       status = ATMI_STATUS_ERROR;
88       break;
89     }
90   } else {
91     *value = 0;
92     status = ATMI_STATUS_ERROR;
93   }
94 
95   return status;
96 }
97