From 42608c87ea2f27e76bc044f2ec4fde2f3684b20d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 21 Oct 2023 22:19:43 +0200 Subject: [PATCH] Fix cmake dtype check (#989) * Fix instances dtype check * Fix source dtypes seletor for examples and tests * Sync with new cmakefile changes * Remove not needed ifdefs * Remove not needed ifdefs [ROCm/composable_kernel commit: ac0e006766185e5fb6342197f4a00599551beb26] --- example/CMakeLists.txt | 92 +++++++++---------- .../gpu/CMakeLists.txt | 4 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 8 +- test/CMakeLists.txt | 92 +++++++++---------- test/data_type/CMakeLists.txt | 6 +- test/data_type/{bf8.cpp => test_bf8.cpp} | 0 test/data_type/{fp8.cpp => test_fp8.cpp} | 0 test/data_type/{int4.cpp => test_int4.cpp} | 0 8 files changed, 97 insertions(+), 105 deletions(-) rename test/data_type/{bf8.cpp => test_bf8.cpp} (100%) rename test/data_type/{fp8.cpp => test_fp8.cpp} (100%) rename test/data_type/{int4.cpp => test_int4.cpp} (100%) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5a53982bb2..c19ba93b69 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -11,31 +11,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) set(test 0) - foreach(type IN LISTS DTYPES) - if(type MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) - #if filename contains a type which doesn't match any selected type, mark it for removal - set(test 1) - endif() - endforeach() + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() if(test EQUAL 1) message("removing example source file ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") @@ -74,31 +70,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) set(test 0) - foreach(type IN LISTS DTYPES) - if(type MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) - #if filename contains a type which doesn't match any selected type, mark it for removal - set(test 1) - endif() - endforeach() + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() if(test EQUAL 1) message("removing example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index e1952a886c..9cb5d0e9aa 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -69,6 +69,10 @@ FOREACH(subdir_path ${dir_list}) message("fp8 instance found!") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "_bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8") + message("bf8 instance found!") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") message("fp16 instance found!") set(add_inst 1) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index a2c4b1f80b..968e8dea2f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -4,8 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT @@ -27,4 +26,9 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_BWD_WEIGHT + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT}) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 265f428b55..1567d8bc69 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,31 +13,27 @@ function(add_test_executable TEST_NAME) if(DEFINED DTYPES) foreach(source IN LISTS ARGN) set(test 0) - foreach(type IN LISTS DTYPES) - if(type MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) - #if filename contains a type which doesn't match any selected type, mark it for removal - set(test 1) - endif() - endforeach() + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() if(test EQUAL 1) message("removing test ${source} ") list(REMOVE_ITEM ARGN "${source}") @@ -72,31 +68,27 @@ function(add_gtest_executable TEST_NAME) if(DEFINED DTYPES) foreach(source IN LISTS ARGN) set(test 0) - foreach(type IN LISTS DTYPES) - if(type MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) - #if filename contains a type which doesn't match any selected type, mark it for removal - set(test 1) - endif() - endforeach() + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() if(test EQUAL 1) message("removing gtest ${source} ") list(REMOVE_ITEM ARGN "${source}") diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 2409ca05c2..0ebfc931ac 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -1,15 +1,15 @@ if (USE_BITINT_EXTENSION_INT4) - add_gtest_executable(test_int4 int4.cpp) + add_gtest_executable(test_int4 test_int4.cpp) if(result EQUAL 0) target_link_libraries(test_int4 PRIVATE utility) endif() endif() -add_gtest_executable(test_fp8 fp8.cpp) +add_gtest_executable(test_fp8 test_fp8.cpp) if(result EQUAL 0) target_link_libraries(test_fp8 PRIVATE utility) endif() -add_gtest_executable(test_bf8 bf8.cpp) +add_gtest_executable(test_bf8 test_bf8.cpp) if(result EQUAL 0) target_link_libraries(test_bf8 PRIVATE utility) endif() diff --git a/test/data_type/bf8.cpp b/test/data_type/test_bf8.cpp similarity index 100% rename from test/data_type/bf8.cpp rename to test/data_type/test_bf8.cpp diff --git a/test/data_type/fp8.cpp b/test/data_type/test_fp8.cpp similarity index 100% rename from test/data_type/fp8.cpp rename to test/data_type/test_fp8.cpp diff --git a/test/data_type/int4.cpp b/test/data_type/test_int4.cpp similarity index 100% rename from test/data_type/int4.cpp rename to test/data_type/test_int4.cpp