diff --git a/CMakeLists.txt b/CMakeLists.txt index b09da41a83..1d4ac3f14f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,12 +32,10 @@ if (DTYPES) if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) set(CK_ENABLE_FP8 "ON") - add_compile_options(-Wno-bit-int-extension) endif() if (DTYPES MATCHES "bf8") add_definitions(-DCK_ENABLE_BF8) set(CK_ENABLE_BF8 "ON") - add_compile_options(-Wno-bit-int-extension) endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) @@ -59,9 +57,11 @@ if (DTYPES) else() add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_ALL_DTYPES "ON") - add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8 endif() +#for f8/bf8_t type +add_compile_options(-Wno-bit-int-extension) + if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 703d0e3834..42d7311d3d 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,82 +1,60 @@ add_custom_target(example_gemm_dl) add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_dl example_gemm_dl_fp32) -endif() +add_example_dependencies(example_gemm_dl example_gemm_dl_fp32) + add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_dl example_gemm_dl_fp16) -endif() +add_example_dependencies(example_gemm_dl example_gemm_dl_fp16) + add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) + add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_dl example_gemm_dl_int8) -endif() +add_example_dependencies(example_gemm_dl example_gemm_dl_int8) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl 
example_gemm_dl_int4) + add_example_dependencies(example_gemm_dl example_gemm_dl_int4) endif(USE_BITINT_EXTENSION_INT4) add_custom_target(example_gemm_xdl) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) - endif() + add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) endif() add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) - add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) -endif() +add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_int8) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) 
if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) - add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp8) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) -endif() +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 9fe833dda0..3486b15571 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,28 +1,24 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - if(result EQUAL 0) - 
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_add_add_fastgelu_xdl) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + set(target 1) endif() - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - endif() - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - 
endif(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 150d146e31..222a3b7c0b 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,28 +1,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - if(result EQUAL 0) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - endif() - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - endif() - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - endif() - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() \ No newline at end of file 
+ if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_fwd_reduce_xdl) + + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index f2c76b76f2..84040fcf5c 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,44 +1,32 @@ add_custom_target(example_grouped_gemm_xdl) add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) + add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) + add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 
grouped_gemm_multiple_d_dl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) + add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) + add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) + add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -endif() +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) -endif() +add_example_dependencies(example_grouped_gemm_xdl 
example_grouped_gemm_xdl_fixed_nk_fp8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) - endif() + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 67cf5666d7..5955e1d6cb 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,62 +1,48 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - endif() - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - endif() - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - endif() + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_reduce_xdl) + add_custom_target(example_gemm_reduce_xdl_max) + add_custom_target(example_gemm_reduce_xdl_mean_meansquare) + 
add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - endif() - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - endif() + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - endif() - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - endif() + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - endif() - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - endif() - - add_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare 
example_gemm_mean_meansquare_xdl_fp16) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - if(result EQUAL 0) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - endif() - set(target 1) - endif() + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + + add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + endif() + set(target 1) + endif() endforeach() diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 1ecbab5825..2b0c4a28ce 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ 
b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -2,36 +2,30 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - endif() - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - endif() - if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) + + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + + if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight 
example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + endif() + set(target 1) + endif() + + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) + set(target 1) endif() - endif() - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) - endif() - set(target 1) - endif() endforeach() add_custom_target(example_grouped_conv_bwd_weight_dl) + add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) -endif() +add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index f401f7187f..44585b11d0 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,22 +1,18 @@ add_custom_target(example_cgemm_xdl) add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) + add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) + add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -if(result 
EQUAL 0) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) + add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) -if(result EQUAL 0) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index c80540de7d..4cb45be7c9 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,23 +1,18 @@ add_custom_target(example_batched_gemm_xdl) + add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) -if(result EQUAL 0) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) -endif() +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) + add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) -endif() +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) + add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) -endif() +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) + add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) -if(result EQUAL 0) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) -endif() 
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) - if(result EQUAL 0) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) - endif() + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 4ab5bd75f2..3a8c2ef52f 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) - endif() - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - endif() - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - endif() - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - endif() - 
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() - endif() # USE_BITINT_EXTENSION_INT4 + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_custom_target(example_grouped_conv_fwd_multiple_d) - set(target 1) - endif() + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) + + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + 
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + endif() # USE_BITINT_EXTENSION_INT4 + + set(target 1) + endif() endforeach() set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - set(target 1) - endif() + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + set(target 1) + endif() endforeach() diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 25eb44ae59..2a24abf094 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,31 +1,23 @@ add_custom_target(example_gemm_scale_softmax_gemm) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) -endif() -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) -endif() -add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm 
example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) -endif() -add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -endif() -add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -endif() -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) -endif() -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -endif() +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 
batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 8970a57648..eff6b6f3fa 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -4,28 +4,23 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_splitK_gemm_xdl) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - if(result EQUAL 0) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - endif() - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - endif() - add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) - if(result EQUAL 0) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) - endif() - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - if(result EQUAL 0) - 
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - endif() + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) + + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) + + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - if(result EQUAL 0) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - endif() + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) endif() + set(target 1) endif() endforeach() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 96d1a6c3c0..1ae179e950 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -2,27 +2,26 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) - endif() - 
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) - endif() - set(target 1) - endif() + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + + add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) + + add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) + + set(target 1) + endif() endforeach() foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) - if(result EQUAL 0) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) - endif() - set(target 1) - endif() + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + + add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) + + set(target 1) + endif() endforeach() diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index bcf47b4926..8b850c89a9 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,14 +1,10 @@ add_custom_target(example_permute) add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_permute example_permute_1xHxW_fp16) -endif() 
+add_example_dependencies(example_permute example_permute_1xHxW_fp16) + add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_permute example_permute_NxHxW_fp16) -endif() +add_example_dependencies(example_permute example_permute_NxHxW_fp16) + add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) -if(result EQUAL 0) - add_dependencies(example_permute example_permute_HxWx4_fp16) -endif() +add_example_dependencies(example_permute example_permute_HxWx4_fp16) diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index a2dec9e805..4dc6c8b4e0 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,12 +1,15 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_im2col_col2im) - add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) - add_dependencies(example_im2col_col2im example_image_to_column_f32) - add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) - add_dependencies(example_im2col_col2im example_column_to_image_f32) - set(target 1) - endif() + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_im2col_col2im) + + add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) + add_example_dependencies(example_im2col_col2im example_image_to_column_f32) + + add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) + add_example_dependencies(example_im2col_col2im example_column_to_image_f32) + + set(target 1) + endif() endforeach() diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 7f8704f281..5a53982bb2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -62,6 +62,12 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result ${result} 
PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(result EQUAL 0) + add_dependencies(${EXAMPLE_NAME} ${FILE_NAME}) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 05cd5c5b8d..23fdd23e10 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -113,7 +113,6 @@ struct PassThrough } #endif -#if defined CK_ENABLE_FP8 template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { @@ -143,9 +142,7 @@ struct PassThrough { y = type_convert(x); } -#endif -#if defined CK_ENABLE_BF8 template <> __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const { @@ -175,7 +172,6 @@ struct PassThrough { y = ck::type_convert(x); } -#endif }; struct UnaryConvert @@ -204,7 +200,6 @@ struct ConvertBF16RTN } }; -#if defined CK_ENABLE_FP8 struct ConvertF8SR { // convert to fp8 using stochastic rounding (SR) @@ -221,7 +216,6 @@ struct ConvertF8SR y = f8_convert_sr(x); } }; -#endif struct Scale { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index c8e56fbc56..835075b7f2 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -462,7 +462,6 @@ struct mfma_type } }; -#if defined CK_ENABLE_FP8 template <> struct mfma_type { @@ -506,9 +505,7 @@ struct mfma_type intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); } }; -#endif -#if defined CK_ENABLE_BF8 template <> struct mfma_type { @@ -552,9 +549,7 @@ struct mfma_type 
intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); } }; -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> struct mfma_type { @@ -598,9 +593,7 @@ struct mfma_type intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); } }; -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> struct mfma_type { @@ -644,7 +637,6 @@ struct mfma_type intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); } }; -#endif template static constexpr auto GetMfma() { @@ -804,9 +795,7 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32f8f8; } -#endif -#if defined CK_ENABLE_BF8 template <> static constexpr auto GetMfma() { @@ -818,9 +807,7 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32bf8bf8; } -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> static constexpr auto GetMfma() { @@ -832,9 +819,7 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32f8bf8; } -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> static constexpr auto GetMfma() { @@ -846,7 +831,6 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32bf8f8; } -#endif static constexpr auto selected_mfma = mfma_type()>{}; @@ -1051,18 +1035,10 @@ struct XdlopsGemm static_assert( is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value -#if defined CK_ENABLE_FP8 - || is_same::value -#endif -#if defined CK_ENABLE_BF8 - || is_same::value -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - || (is_same::value && is_same::value) || - (is_same::value && is_same::value) -#endif - , + is_same::value || is_same::value || + is_same::value || + (is_same::value && is_same::value) || + (is_same::value && is_same::value), "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index bf6241e46e..afc066405e 100644 --- 
a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,10 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" +#pragma once namespace ck { @@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; -#if defined CK_ENABLE_FP8 template struct intrin_mfma_f32_32x32x16f8f8; @@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> #endif } }; -#endif -#if defined CK_ENABLE_BF8 template struct intrin_mfma_f32_32x32x16bf8bf8; @@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> #endif } }; -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template struct intrin_mfma_f32_32x32x16f8bf8; @@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> #endif } }; -#endif -#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template struct intrin_mfma_f32_32x32x16bf8f8; @@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> #endif } }; -#endif + } // namespace ck -#endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 89100577aa..ceaca27a45 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -9,15 +9,9 @@ namespace ck { using bhalf_t = ushort; using half_t = _Float16; -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -using int4_t = _BitInt(4); -#endif -#if defined CK_ENABLE_FP8 -using f8_t = _BitInt(8); -#endif -#if defined CK_ENABLE_BF8 -using bf8_t = unsigned _BitInt(8); -#endif +using int4_t = _BitInt(4); +using f8_t = _BitInt(8); +using bf8_t = unsigned _BitInt(8); // vector_type template @@ -148,23 +142,19 @@ struct scalar_type }; #endif -#if defined CK_ENABLE_FP8 template <> struct scalar_type { using type = f8_t; static constexpr index_t vector_size = 1; }; -#endif -#if defined CK_ENABLE_BF8 template <> struct scalar_type { using type = bf8_t; static constexpr index_t vector_size 
= 1; }; -#endif template struct vector_type @@ -968,24 +958,20 @@ using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; // f8 -#if defined CK_ENABLE_FP8 using f8x2_t = typename vector_type::type; using f8x4_t = typename vector_type::type; using f8x8_t = typename vector_type::type; using f8x16_t = typename vector_type::type; using f8x32_t = typename vector_type::type; using f8x64_t = typename vector_type::type; -#endif // bf8 -#if defined CK_ENABLE_BF8 using bf8x2_t = typename vector_type::type; using bf8x4_t = typename vector_type::type; using bf8x8_t = typename vector_type::type; using bf8x16_t = typename vector_type::type; using bf8x32_t = typename vector_type::type; using bf8x64_t = typename vector_type::type; -#endif template struct NumericLimits @@ -1033,7 +1019,6 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#if defined CK_ENABLE_FP8 template <> struct NumericLimits { @@ -1056,9 +1041,7 @@ struct NumericLimits __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } }; -#endif -#if defined CK_ENABLE_BF8 template <> struct NumericLimits { @@ -1081,7 +1064,6 @@ struct NumericLimits __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } }; -#endif template struct NumericUtils @@ -1120,22 +1102,18 @@ struct NumericUtils using bitwise_type = uint16_t; }; -#if defined CK_ENABLE_FP8 template <> struct NumericUtils { static constexpr int exp = 4; static constexpr int mant = 3; }; -#endif -#if defined CK_ENABLE_BF8 template <> struct NumericUtils { static constexpr int exp = 5; static constexpr int mant = 2; }; -#endif - +// } // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 217b339b66..b63c82fe9a 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -6,8 +6,6 @@ #include "ck/utility/data_type.hpp" // these conversions are disabled if native conversions available 
-#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) -#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 namespace ck { // fp8 rounding modes @@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x) } } // namespace ck::utils -#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 -#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index ccbd5db644..aba8baa593 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -95,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } -#if defined CK_ENABLE_FP8 // convert fp32 to fp8 template <> inline __host__ __device__ f8_t type_convert(float x) @@ -173,9 +172,7 @@ inline __host__ __device__ half_t type_convert(f8_t x) return type_convert(type_convert(x)); #endif } -#endif -#if defined CK_ENABLE_BF8 // convert fp32 to bf8 template <> inline __host__ __device__ bf8_t type_convert(float x) @@ -253,7 +250,6 @@ inline __host__ __device__ half_t type_convert(bf8_t x) return type_convert(type_convert(x)); #endif } -#endif // Declare a template function for bf16 conversion using RTN template @@ -316,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(h template __host__ __device__ constexpr Y f8_convert_sr(X x); -#if defined CK_ENABLE_FP8 // convert fp32 to fp8 with stochastic rounding template <> inline __host__ __device__ f8_t f8_convert_sr(float x) @@ -365,9 +360,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) return f8_convert_sr(type_convert(x)); #endif } -#endif -#if defined CK_ENABLE_BF8 // convert fp32 to bf8 with stochastic rounding template <> inline __host__ __device__ bf8_t f8_convert_sr(float x) @@ -417,6 +410,5 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) return f8_convert_sr(type_convert(x)); #endif } -#endif } // namespace ck diff 
--git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index ea11fd2e1a..ba21e7a251 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -20,12 +20,8 @@ using F16 = ck::half_t; using BF16 = ck::bhalf_t; using I8 = int8_t; using I32 = int32_t; -#if defined CK_ENABLE_FP8 -using F8 = ck::f8_t; -#endif -#if defined CK_ENABLE_BF8 -using BF8 = ck::bf8_t; -#endif +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; using Empty_Tuple = ck::Tuple<>; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp index 571ff0b672..56ab6be0c3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp @@ -240,11 +240,13 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs); } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -267,17 +269,23 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); -#endif } +#endif +#if defined(DL_KERNELS) && defined(CK_ENABLE_FP32) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -306,14 +314,16 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs); } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp index ad2da3364f..5e594366e6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -98,30 +98,31 @@ struct DeviceOperationInstanceFactory< if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); } +#endif #ifdef CK_ENABLE_FP16 - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); } diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index e14df2d750..8ad6ddca9d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -155,7 +155,7 @@ struct DeviceOperationInstanceFactory< std::vector> op_ptrs; #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -180,8 +180,8 @@ struct DeviceOperationInstanceFactory< } #endif #ifdef CK_ENABLE_FP16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -206,8 +206,8 @@ struct DeviceOperationInstanceFactory< } #endif #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -230,8 +230,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 9310e9e57d..f15008349a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -627,8 +627,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - 
is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); @@ -637,9 +637,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( @@ -650,8 +649,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS #ifdef CK_ENABLE_FP32 @@ -662,16 +661,15 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( op_ptrs); @@ -680,7 +678,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -698,8 +696,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( @@ -710,9 +708,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( @@ -723,8 +720,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && @@ -739,8 +736,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS 
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( @@ -751,9 +748,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( @@ -765,7 +761,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -783,8 +779,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( @@ -799,9 +795,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( @@ -822,8 +817,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && @@ -838,10 +833,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -856,9 +850,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( @@ -879,9 +872,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && - is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( op_ptrs); diff --git a/library/include/ck/library/utility/check_err.hpp 
b/library/include/ck/library/utility/check_err.hpp index c0f9ba2edc..a3df884eee 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -230,7 +230,6 @@ check_err(const Range& out, return res; } -#if defined CK_ENABLE_FP8 template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, f8_t>), @@ -275,9 +274,7 @@ check_err(const Range& out, } return res; } -#endif -#if defined CK_ENABLE_BF8 template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, bf8_t>), @@ -322,7 +319,6 @@ check_err(const Range& out, } return res; } -#endif } // namespace utils } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 9c1f72eca3..e1952a886c 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -2,44 +2,44 @@ function(add_instance_library INSTANCE_NAME) message("adding instance ${INSTANCE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS ARGN) - set(test 0) - foreach(type IN LISTS DTYPES) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) if(type MATCHES "fp16") - set(type1 "_f16") + set(type1 "_f16") elseif(type MATCHES "fp32") - set(type1 "_f32") + set(type1 "_f32") elseif(type MATCHES "fp8") - set(type1 "_f8") + set(type1 "_f8") elseif(type MATCHES "bf16") - set(type1 "_b16") + set(type1 "_b16") elseif(type MATCHES "fp64") - set(type1 "_f64") + set(type1 "_f64") elseif(type MATCHES "int8") - set(type1 "_i8") + set(type1 "_i8") endif() #make an exception for reduction kernels - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() + if("${source}" MATCHES "${type}" OR 
"${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) - #if filename contains a type which doesn't match any selected type, mark it for removal - set(test 1) + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) endif() - endforeach() - if(test EQUAL 1) + endforeach() + if(test EQUAL 1) message("removing instance ${source} ") list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() + endif() + endforeach() endif() foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") - message("removing dl instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") + message("removing dl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list @@ -49,8 +49,10 @@ function(add_instance_library INSTANCE_NAME) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(${INSTANCE_NAME}) set(result 0) + message("add_instance_library ${INSTANCE_NAME}") + else() + message("skip_instance_libary ${INSTANCE_NAME}") endif() - #message("add_instance_library returns ${result}") set(result ${result} PARENT_SCOPE) endfunction(add_instance_library 
INSTANCE_NAME) @@ -58,65 +60,70 @@ endfunction(add_instance_library INSTANCE_NAME) file(GLOB dir_list LIST_DIRECTORIES true *) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) -set(target_dir) -IF(IS_DIRECTORY "${subdir_path}") - set(cmake_instance) - file(READ "${subdir_path}/CMakeLists.txt" cmake_instance) - set(add_inst 0) - if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") + set(target_dir) + IF(IS_DIRECTORY "${subdir_path}") + set(cmake_instance) + file(READ "${subdir_path}/CMakeLists.txt" cmake_instance) + set(add_inst 0) + if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") message("fp8 instance found!") set(add_inst 1) - endif() - if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") + endif() + if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") message("fp16 instance found!") set(add_inst 1) - endif() - if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") + endif() + if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") message("fp32 instance found!") set(add_inst 1) - endif() - if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") + endif() + if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") message("fp64 instance found!") set(add_inst 1) - endif() - if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16") + endif() + if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16") message("bf16 instance found!") set(add_inst 1) - endif() - if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") + endif() + if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" 
MATCHES "_i8") AND DTYPES MATCHES "int8") message("int8 instance found!") set(add_inst 1) - endif() - if(NOT "${cmake_instance}" MATCHES "_fp8" OR - NOT "${cmake_instance}" MATCHES "_f8" OR - NOT "${cmake_instance}" MATCHES "_fp16" OR - NOT "${cmake_instance}" MATCHES "_f16" OR - NOT "${cmake_instance}" MATCHES "_fp32" OR - NOT "${cmake_instance}" MATCHES "_f32" OR - NOT "${cmake_instance}" MATCHES "_fp64" OR - NOT "${cmake_instance}" MATCHES "_f64" OR - NOT "${cmake_instance}" MATCHES "_bf16" OR - NOT "${cmake_instance}" MATCHES "_int8" OR - NOT "${cmake_instance}" MATCHES "_i8" OR - NOT "${cmake_instance}" MATCHES "_int4" OR - NOT DEFINED DTYPES) - message("instance should be built for all types!") - set(add_inst 1) - endif() - if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8") - message("quantization instances will not be built!") - set(add_inst 0) - endif() - if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS) - message("Found only dl instances, but DL_KERNELS is not set. 
Skipping.") + endif() + if(NOT ("${cmake_instance}" MATCHES "_fp8" OR + "${cmake_instance}" MATCHES "_f8" OR + "${cmake_instance}" MATCHES "_fp16" OR + "${cmake_instance}" MATCHES "_f16" OR + "${cmake_instance}" MATCHES "_fp32" OR + "${cmake_instance}" MATCHES "_f32" OR + "${cmake_instance}" MATCHES "_fp64" OR + "${cmake_instance}" MATCHES "_f64" OR + "${cmake_instance}" MATCHES "_bf16" OR + "${cmake_instance}" MATCHES "_int8" OR + "${cmake_instance}" MATCHES "_i8" OR + "${cmake_instance}" MATCHES "_int4")) + message("instance should be built for all types!") + set(add_inst 1) + endif() + if(NOT DEFINED DTYPES) + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8")) + message("quantization instances will not be built!") set(add_inst 0) - endif() - if(add_inst EQUAL 1) - get_filename_component(target_dir ${subdir_path} NAME) - add_subdirectory(${target_dir}) - list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>) - endif() -ENDIF() + endif() + if(("${cmake_instance}" MATCHES "ONLY DL_KERNELS") AND (NOT DEFINED DL_KERNELS)) + message("Found only dl instances, but DL_KERNELS is not set. 
Skipping.") + set(add_inst 0) + endif() + if((add_inst EQUAL 1)) + get_filename_component(target_dir ${subdir_path} NAME) + add_subdirectory(${target_dir}) + list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>) + message("add_instance_directory ${subdir_path}") + else() + message("skip_instance_directory ${subdir_path}") + endif() + ENDIF() ENDFOREACH() add_library(device_operations STATIC ${CK_DEVICE_INSTANCES}) @@ -158,11 +165,11 @@ target_compile_options(device_operations PRIVATE # install(TARGETS device_operations LIBRARY DESTINATION lib) rocm_install(TARGETS device_operations - EXPORT device_operationsTargets) + EXPORT device_operationsTargets) rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) rocm_install(EXPORT device_operationsTargets - FILE composable_kerneldevice_operationsTargets.cmake - NAMESPACE composable_kernel:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + FILE composable_kerneldevice_operationsTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 0927cf225d..836e671bf2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,11 +1,10 @@ -add_instance_library(device_grouped_conv3d_bwd_data_instance +set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp 
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp @@ -13,5 +12,11 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp -) + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp) + +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_BWD_DATA + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp) +endif() + +add_instance_library(device_grouped_conv3d_bwd_data_instance ${GROUPED_CONV3D_BWD_DATA}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index c3cc4cb054..bada661028 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,33 +1,32 @@ -add_instance_library(device_grouped_conv3d_fwd_instance +set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp 
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp - - xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp -) + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp) + +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) +endif() 
+ +add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index 51a42c3d8d..3b48954d22 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,18 +1,10 @@ set(GROUPED_GEMM_FIXED_NK_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp) -endif() - -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp) -endif() - -if((DTYPES MATCHES "int8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp) - list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp) -endif() +list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) diff --git 
a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 8d9a47cd6b..e1d8afc2a4 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -25,8 +25,6 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_contraction_bilinear.cpp - profile_contraction_scale.cpp profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp ) @@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) endif() +if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) +endif() + set(PROFILER_EXECUTABLE ckProfiler) add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) @@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) @@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) + 
+if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +endif() + + + if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index bd1016f286..6ed7cf5e48 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -86,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#ifdef CK_ENABLE_FP8 - using F8 = ck::f8_t; -#endif -#ifdef CK_ENABLE_BF8 - using BF8 = ck::bf8_t; -#endif + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; using namespace ck::tensor_layout::convolution; @@ -141,59 +137,59 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - else if(data_type == ConvDataType::F16_F16_F16) + if(data_type == ConvDataType::F16_F16_F16) { return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } - else if(data_type == ConvDataType::BF16_F32_BF16) + if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } - else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == 
ConvDataType::F32_F32_F32) { return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - else if(data_type == ConvDataType::F16_F16_F16) + if(data_type == ConvDataType::F16_F16_F16) { return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } - else if(data_type == ConvDataType::BF16_F32_BF16) + if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } - else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - else if(data_type == ConvDataType::F16_F16_F16) + if(data_type == ConvDataType::F16_F16_F16) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } - else if(data_type == ConvDataType::BF16_F32_BF16) + if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } - else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - else if(data_type == ConvDataType::F16_F16_F16) + if(data_type == ConvDataType::F16_F16_F16) { return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } - else if(data_type == ConvDataType::BF16_F32_BF16) + if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); @@ -204,22 +200,22 @@ int 
profile_grouped_conv_bwd_weight(int argc, char* argv[]) I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); } } - else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - else if(data_type == ConvDataType::F16_F16_F16) + if(data_type == ConvDataType::F16_F16_F16) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } - else if(data_type == ConvDataType::BF16_F32_BF16) + if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } - else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) + if(data_type == ConvDataType::F16_F16_F16_BF8_F8) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4b2c7bbf38..265f428b55 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,40 +11,40 @@ function(add_test_executable TEST_NAME) message("adding test ${TEST_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS ARGN) - set(test 0) - foreach(type IN LISTS DTYPES) - if(type MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source 
MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing test ${source} ") + list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - if(test EQUAL 1) - message("removing test ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - endif() - foreach(source IN LISTS ARGN) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") list(REMOVE_ITEM ARGN "${source}") @@ -70,38 +70,38 @@ function(add_gtest_executable TEST_NAME) message("adding gtest ${TEST_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS ARGN) - set(test 0) - foreach(type IN LISTS DTYPES) - if(type 
MATCHES "fp16") - set(type1 "_f16") - elseif(type MATCHES "fp32") - set(type1 "_f32") - elseif(type MATCHES "fp8") - set(type1 "_f8") - elseif(type MATCHES "bf16") - set(type1 "_b16") - elseif(type MATCHES "fp64") - set(type1 "_f64") - elseif(type MATCHES "int8") - set(type1 "_i8") - endif() - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") - #if filename matches any selected type, exit type loop and do no exclude the file from the list - set(test 0) - break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing gtest ${source} ") + list(REMOVE_ITEM ARGN 
"${source}") endif() endforeach() - if(test EQUAL 1) - message("removing gtest ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() endif() foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index 1f6e0ed341..a86e72fddb 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,11 +1,13 @@ -add_gtest_executable(test_contraction test_contraction.cpp) -target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) - target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - set(target 1) - endif() + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) + add_gtest_executable(test_contraction test_contraction.cpp) + target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + set(target 1) + endif() + endif() endforeach() diff --git a/test/conv_tensor_rearrange/CMakeLists.txt b/test/conv_tensor_rearrange/CMakeLists.txt index f6ad263242..05ca4a9ffb 100644 --- a/test/conv_tensor_rearrange/CMakeLists.txt +++ b/test/conv_tensor_rearrange/CMakeLists.txt @@ -1,4 +1,5 @@ add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp) target_link_libraries(test_conv_tensor_rearrange 
PRIVATE utility device_image_to_column_instance device_column_to_image_instance) + add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp) target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility)