diff --git a/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp index 88a2d06901..6e2857e8a1 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp +++ b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp @@ -41,4 +41,10 @@ int get_cu_count() return dev_prop.multiProcessorCount; } -int main() { return get_cu_count(); } +int main() +{ + + std::cout << get_cu_count(); + + return 0; +} diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake index 4f18b5dcbe..148d57976a 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake @@ -29,30 +29,48 @@ function(get_cu_count cu_count_arg) execute_process( COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH} - RESULT_VARIABLE compile_result + RESULT_VARIABLE compile_exit_code ) - if (NOT compile_result EQUAL 0) + if (NOT compile_exit_code EQUAL 0) message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n") endif() + # Get the HIP library directory + get_filename_component(HIP_COMPILER_DIR ${CMAKE_HIP_COMPILER} DIRECTORY) + get_filename_component(HIP_ROOT_DIR ${HIP_COMPILER_DIR} DIRECTORY) + set(HIP_LIB_DIR "${HIP_ROOT_DIR}/lib") + + # Set library path for runtime execution + if(WIN32) + set(ENV{PATH} "${HIP_LIB_DIR};$ENV{PATH}") + else() + set(ENV{LD_LIBRARY_PATH} "${HIP_LIB_DIR}:$ENV{LD_LIBRARY_PATH}") + endif() + execute_process( COMMAND ${CPP_EXE_PATH} OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_VARIABLE standard_error - RESULT_VARIABLE queried_cu_count + OUTPUT_VARIABLE queried_cu_count + RESULT_VARIABLE queried_cu_count_exit_code ) if (standard_error) message(STATUS "Error information from attempting to query HIP device and properties:\n" "${standard_error}") endif() + + if (NOT queried_cu_count_exit_code EQUAL 0) + message(STATUS "Failed to run ${CPP_EXE_PATH} to query the device's CU count") + + endif() # Delete the generated cu_count executable file(REMOVE "${CPP_EXE_PATH}") - if(queried_cu_count EQUAL 0) + if((queried_cu_count STREQUAL "0") OR (NOT queried_cu_count_exit_code EQUAL 0)) message(WARNING "Unable to query the number of Compute Units. \ Please use the CU_COUNT CLI option to pass in the \ number of Compute Units for your target device; otherwise, \