int gpu_index; #ifdef GPU #include #include #include #include #include #ifdef CLBLAS #include #endif #include "opencl.h" #include "utils.h" #include "activations.h" cl_info cl = {0}; void check_error(cl_info info) { // clFinish(cl.queue); if (info.error != CL_SUCCESS) { printf("\n Error number %d", info.error); abort(); exit(1); } } #define MAX_DEVICES 10 cl_info cl_init(int index) { cl_info info; info.initialized = 0; if(index < 0) error("Won't initialize negative gpu id\n"); cl_uint num_platforms, num_devices; // Fetch the Platform and Device IDs; we only want one. cl_device_id devices[MAX_DEVICES]; info.error=clGetPlatformIDs(1, &info.platform, &num_platforms); check_error(info); info.error=clGetDeviceIDs(info.platform, CL_DEVICE_TYPE_ALL, MAX_DEVICES, devices, &num_devices); check_error(info); index = index%num_devices; info.device = devices[index]; check_error(info); cl_context_properties properties[]={ CL_CONTEXT_PLATFORM, (cl_context_properties)info.platform, 0}; // Note that nVidia's OpenCL requires the platform property info.context=clCreateContext(properties, 1, &info.device, 0, 0, &info.error); check_error(info); info.queue = clCreateCommandQueue(info.context, info.device, 0, &info.error); check_error(info); #ifdef CLBLAS info.error = clblasSetup(); #endif check_error(info); info.initialized = 1; #ifdef VERBOSE printf("=== %d OpenCL platform(s) found: ===\n", num_platforms); char buffer[10240]; clGetPlatformInfo(info.platform, CL_PLATFORM_PROFILE, 10240, buffer, NULL); printf(" PROFILE = %s\n", buffer); clGetPlatformInfo(info.platform, CL_PLATFORM_VERSION, 10240, buffer, NULL); printf(" VERSION = %s\n", buffer); clGetPlatformInfo(info.platform, CL_PLATFORM_NAME, 10240, buffer, NULL); printf(" NAME = %s\n", buffer); clGetPlatformInfo(info.platform, CL_PLATFORM_VENDOR, 10240, buffer, NULL); printf(" VENDOR = %s\n", buffer); clGetPlatformInfo(info.platform, CL_PLATFORM_EXTENSIONS, 10240, buffer, NULL); printf(" EXTENSIONS = %s\n", buffer); check_error(info); if(num_devices > MAX_DEVICES) num_devices = MAX_DEVICES; printf("=== %d OpenCL device(s) found on platform:\n", num_devices); int i; for (i=0; i