From 4bea06a519ea84f58eea274b666128304e336b70 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Aug 2023 14:56:10 -0700 Subject: [PATCH] Allow building CK for specific data types and split off last remaining DL instances. (#830) * properly split conv_nd_bwd_data instances * split conv2d_fwd instance data types * split the gemm, conv2d_fwd and batched_gemm_softamx_gemm * split the tests by data types where possible * filter examples by DTYPES * split few remaining examples by DTYPES * filter most instances by DTYPES * add new lines at end of headers, fix grouped_gemm profiler * fix syntax * split the ckprofiler instances by DTYPES * split the conv2d and quantization DL and XDL instances * fix the splitting of conv2d DL instances * split softmax and pool_fwd tests for fp16 and fp32 types * fix syntax * fix the dl_int8 quantization instances isolation [ROCm/composable_kernel commit: 08eb17692978f2af709deb59c98aae6db4c82b6b] --- example/01_gemm/CMakeLists.txt | 78 ++++++++-------- example/02_gemm_bilinear/CMakeLists.txt | 2 + example/03_gemm_bias_relu/CMakeLists.txt | 2 + .../04_gemm_add_add_fastgelu/CMakeLists.txt | 30 ++++--- example/09_convnd_fwd/CMakeLists.txt | 34 +++++-- .../CMakeLists.txt | 24 +++-- example/13_pool2d_fwd/CMakeLists.txt | 9 +- example/14_gemm_quantization/CMakeLists.txt | 4 +- example/15_grouped_gemm/CMakeLists.txt | 38 ++++---- .../CMakeLists.txt | 54 +++++------ example/17_convnd_bwd_data/CMakeLists.txt | 8 +- example/18_batched_gemm_reduce/CMakeLists.txt | 2 + .../20_grouped_conv_bwd_weight/CMakeLists.txt | 26 +++--- example/21_gemm_layernorm/CMakeLists.txt | 2 + example/22_cgemm/CMakeLists.txt | 25 +++--- example/24_batched_gemm/CMakeLists.txt | 28 +++--- example/25_gemm_bias_e_permute/CMakeLists.txt | 6 +- example/26_contraction/CMakeLists.txt | 13 +-- example/27_layernorm/CMakeLists.txt | 6 +- .../CMakeLists.txt | 4 +- .../CMakeLists.txt | 8 +- .../CMakeLists.txt | 40 +++++---- 
example/31_batched_gemm_gemm/CMakeLists.txt | 17 ++-- .../CMakeLists.txt | 36 +++++--- example/35_splitK_gemm/CMakeLists.txt | 27 +++--- .../CMakeLists.txt | 4 +- .../CMakeLists.txt | 4 +- example/39_permute/CMakeLists.txt | 16 ++-- .../40_conv2d_fwd_quantization/CMakeLists.txt | 30 +++---- .../41_grouped_conv_conv_fwd/CMakeLists.txt | 16 +++- example/42_groupnorm/CMakeLists.txt | 8 +- .../CMakeLists.txt | 8 +- example/44_elementwise_permute/CMakeLists.txt | 6 +- example/46_gemm_add_multiply/CMakeLists.txt | 8 +- example/48_pool3d_fwd/CMakeLists.txt | 5 +- example/49_maxpool2d_bwd/CMakeLists.txt | 12 ++- example/50_put_element/CMakeLists.txt | 4 +- .../gpu/batched_gemm.hpp | 35 +++++--- .../gpu/batched_gemm_add_relu_gemm_add.hpp | 3 +- .../gpu/batched_gemm_bias_permute.hpp | 3 +- ...batched_gemm_bias_softmax_gemm_permute.hpp | 12 ++- .../gpu/batched_gemm_gemm.hpp | 3 +- .../gpu/batched_gemm_multi_d.hpp | 13 +-- .../gpu/batched_gemm_softmax_gemm.hpp | 3 +- .../gpu/batched_gemm_softmax_gemm_permute.hpp | 11 ++- .../gpu/contraction_bilinear.hpp | 14 +-- .../gpu/contraction_scale.hpp | 14 +-- .../gpu/convolution_backward_data.hpp | 89 +++++++++++++------ .../gpu/convolution_forward.hpp | 26 ++++-- .../gpu/elementwise_normalization.hpp | 3 +- .../tensor_operation_instance/gpu/gemm.hpp | 4 + .../gpu/gemm_add_relu_add_layernorm.hpp | 3 +- .../gpu/gemm_bilinear.hpp | 3 +- .../gpu/gemm_fastgelu.hpp | 3 +- .../gpu/gemm_streamk.hpp | 3 +- .../gpu/grouped_gemm.hpp | 3 +- .../gpu/normalization.hpp | 17 ++-- .../gpu/pool2d_fwd.hpp | 17 ++-- .../gpu/pool3d_fwd.hpp | 17 ++-- .../gpu/quantization/gemm_quantization.hpp | 15 +++- ...n_bias_forward_perchannel_quantization.hpp | 13 ++- ...ion_bias_forward_perlayer_quantization.hpp | 13 ++- ...lution_forward_perchannel_quantization.hpp | 11 ++- ...volution_forward_perlayer_quantization.hpp | 11 ++- .../tensor_operation_instance/gpu/softmax.hpp | 10 ++- .../gpu/batched_gemm/CMakeLists.txt | 44 +++++---- .../CMakeLists.txt | 2 + 
.../batched_gemm_bias_permute/CMakeLists.txt | 3 +- .../gpu/batched_gemm_gemm/CMakeLists.txt | 2 + .../gpu/batched_gemm_reduce/CMakeLists.txt | 3 +- .../batched_gemm_softmax_gemm/CMakeLists.txt | 3 +- .../CMakeLists.txt | 16 ++-- .../gpu/contraction_bilinear/CMakeLists.txt | 24 ++--- .../gpu/contraction_scale/CMakeLists.txt | 24 ++--- .../gpu/conv2d_bwd_data/CMakeLists.txt | 12 ++- ...wd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp | 3 +- ...wd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp | 3 +- ...d_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp | 3 +- ..._data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 3 +- .../gpu/conv2d_fwd/CMakeLists.txt | 23 +++-- ...d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 3 +- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp | 3 +- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 3 +- ...d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp | 3 +- .../elementwise_normalization/CMakeLists.txt | 2 + .../gpu/gemm_add_add_fastgelu/CMakeLists.txt | 2 + .../gpu/gemm_add_fastgelu/CMakeLists.txt | 2 + .../CMakeLists.txt | 2 + .../gpu/gemm_bilinear/CMakeLists.txt | 2 + .../gpu/gemm_fastgelu/CMakeLists.txt | 2 + .../gpu/gemm_streamk/CMakeLists.txt | 2 + .../gpu/grouped_gemm/CMakeLists.txt | 2 + .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 2 + .../gpu/normalization/CMakeLists.txt | 20 +++-- .../gpu/pool_fwd/CMakeLists.txt | 24 ++--- .../gpu/quantization/CMakeLists.txt | 38 ++++---- .../gpu/softmax/CMakeLists.txt | 16 ++-- profiler/src/CMakeLists.txt | 48 +++++----- profiler/src/profile_grouped_gemm.cpp | 4 +- test/CMakeLists.txt | 2 +- test/batched_gemm/CMakeLists.txt | 35 ++++---- test/batched_gemm_gemm/CMakeLists.txt | 12 +-- test/batched_gemm_multi_d/CMakeLists.txt | 5 +- test/batched_gemm_reduce/CMakeLists.txt | 10 ++- test/batched_gemm_softmax_gemm/CMakeLists.txt | 12 +-- .../CMakeLists.txt | 34 +++---- test/elementwise_normalization/CMakeLists.txt | 13 ++- test/gemm_layernorm/CMakeLists.txt | 2 + test/gemm_reduce/CMakeLists.txt | 8 +- test/grouped_gemm/CMakeLists.txt | 2 + 
test/normalization/CMakeLists.txt | 35 ++++---- test/pool_fwd/test_avg_pool2d_fwd.cpp | 4 + test/pool_fwd/test_avg_pool3d_fwd.cpp | 6 +- test/pool_fwd/test_max_pool2d_fwd.cpp | 6 +- test/pool_fwd/test_max_pool3d_fwd.cpp | 6 +- test/softmax/test_softmax_rank3.cpp | 5 +- test/softmax/test_softmax_rank4.cpp | 5 +- 117 files changed, 976 insertions(+), 590 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 42a3499385..7ce866931c 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,31 +1,44 @@ -add_custom_target(example_gemm_dl) +if(DL_KERNELS) + add_custom_target(example_gemm_dl) -add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) -add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) + add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_fp32) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_fp16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_int8) + endif() -add_dependencies(example_gemm_dl example_gemm_dl_fp32) -add_dependencies(example_gemm_dl example_gemm_dl_fp16) - -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int8) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_int4) + endif(USE_BITINT_EXTENSION_INT4) endif() -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int4) -endif(USE_BITINT_EXTENSION_INT4) - - add_custom_target(example_gemm_xdl) +if(DTYPES MATCHES 
"fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) -add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) -add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) + if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_custom_target(example_gemm_wmma) + add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) + add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) + endif() -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) +endif() + +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) +endif() if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) @@ -37,22 +50,17 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) - -add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) - -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES 
"gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed + add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) endif() add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_f8) +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_f8) + endif() endif() diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 7ea4564d1b..9989e45e0c 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) @@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() +endif() diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 2f5cba924d..a247a052cb 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS) 
set(target 1) endif() endforeach() +endif() diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 447fa9871b..15ec62c89f 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -3,22 +3,26 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_gemm_add_add_fastgelu_xdl) - - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - 
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + endif() set(target 1) endif() endforeach() \ No newline at end of file diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 90104c1630..1af1e6c858 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + endif() # 
FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + endif() set(target 1) endif() endforeach() -add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) -add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) -add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) +if(DL_KERNELS) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) + endif() +endif() diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 9577b4569a..e7d941ae6b 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -3,14 +3,22 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_convnd_fwd_reduce_xdl) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_dependencies(example_convnd_fwd_reduce_xdl 
example_convnd_fwd_max_xdl_bf16) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index db09c03321..d0f356757b 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ -add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) -add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) - +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) +endif() diff --git 
a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index ca06265e31..3b3ad80dd8 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,6 +1,8 @@ if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) # dlops -add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +if(DL_KERNELS) + add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +endif() # xdlops list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 9df256c38d..dca60b0e73 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,21 +1,25 @@ add_custom_target(example_grouped_gemm_xdl) - -add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) -add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) -add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) -add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) -add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) -add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) - - -add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp32 - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_xdl_bfp16 - example_grouped_gemm_xdl_int8 - example_grouped_gemm_multiple_d_dl_fp16 - example_grouped_gemm_xdl_splitk_fp16) - +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) + 
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) + add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) + add_dependencies(example_grouped_gemm_xdl + example_grouped_gemm_xdl_fp16 + example_grouped_gemm_multiple_d_dl_fp16 + example_grouped_gemm_xdl_splitk_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index a42b427c61..00786d34a3 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -6,33 +6,33 @@ foreach(gpu IN LISTS GPU_TARGETS) add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - 
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - - add_dependencies(example_gemm_reduce_xdl_max - example_gemm_max_xdl_bf16 - example_gemm_max_xdl_fp16 - example_gemm_max_xdl_fp32 - example_gemm_max_xdl_int8) - - add_dependencies(example_gemm_reduce_xdl_mean_meansquare - example_gemm_mean_meansquare_xdl_fp16 - example_gemm_mean_meansquare_xdl_fp32 - example_gemm_mean_meansquare_xdl_bf16 - example_gemm_add_addsquare_xdl_int8) - - add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 
gemm_mean_meansquare_xdl_fp32.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + endif() + add_dependencies(example_gemm_reduce_xdl example_gemm_reduce_xdl_mean_meansquare example_gemm_reduce_xdl_max diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 8ab9f37bb3..e187bd4337 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) -target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) + if(DL_KERNELS) + add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) + endif() +endif() diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc0..a1bb398af0 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -6,3 +7,4 
@@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() +endif() diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index db170decce..d649567ed2 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -3,18 +3,22 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_weight) - - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 - example_grouped_conv_bwd_weight_xdl_bf16) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + endif() set(target 1) endif() endforeach() -add_custom_target(example_grouped_conv_bwd_weight_dl) - -add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - -add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + add_custom_target(example_grouped_conv_bwd_weight_dl) + add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) + add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) + endif() +endif() \ No 
newline at end of file diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index ff870e7c6b..6a6735efd6 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -9,3 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() +endif() diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 1564561156..854f07fda6 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,16 +1,21 @@ add_custom_target(example_cgemm_xdl) -add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) -add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) - -add_dependencies(example_cgemm_xdl - example_cgemm_xdl_bf16 - example_cgemm_xdl_fp16 - example_cgemm_xdl_fp32 - example_cgemm_xdl_int8) - +add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) add_dependencies(example_cgemm_xdl 
example_cgemm_xdl_int4) diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 7962576e87..48a3b58ff5 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,16 +1,20 @@ add_custom_target(example_batched_gemm_xdl) - -add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) -add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) -add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) - -add_dependencies(example_batched_gemm_xdl - example_batched_gemm_xdl_fp32 - example_batched_gemm_xdl_fp16 - example_batched_gemm_xdl_bfp16 - example_batched_gemm_xdl_int8) - +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt index cbc3c007bc..eb274b2338 100644 
--- a/example/25_gemm_bias_e_permute/CMakeLists.txt +++ b/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,4 @@ -add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) -add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) + add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) +endif() diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index c58751f0dd..6cab88b13e 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,5 +1,8 @@ -add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) - -add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) + add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) +endif() +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) + add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +endif() diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt index 94c23ce774..9cb2cd0766 100644 --- a/example/27_layernorm/CMakeLists.txt +++ b/example/27_layernorm/CMakeLists.txt @@ -1,2 +1,4 @@ 
-add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) -add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) + add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) +endif() diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 44ab16894c..2fda1f62a9 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) +endif() diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 32a87dd200..09c3e6c608 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,7 @@ -add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) + if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) + endif() endif() diff --git 
a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 1c07538b09..e37413c098 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -5,23 +5,29 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_custom_target(example_grouped_conv_fwd_multiple_d) - - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + 
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + set(target 1) endif() endforeach() @@ -29,8 +35,12 @@ endforeach() set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + endif() set(target 1) endif() endforeach() diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt 
b/example/31_batched_gemm_gemm/CMakeLists.txt index 83989bcd97..2074520f8c 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -3,10 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -15,5 +20,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) + endif() endif() diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 8d9aaec85a..0463bf6bd3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,16 +1,24 @@ -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 
batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +endif() 
add_custom_target(example_gemm_scale_softmax_gemm) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) +endif() diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 57ac33fc9a..fa66a5a520 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -3,17 +3,22 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) 
if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_splitK_gemm_xdl) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - - add_dependencies(example_splitK_gemm_xdl - example_splitK_gemm_xdl_fp32 - example_splitK_gemm_xdl_fp16 - example_splitK_gemm_xdl_bfp16 - example_splitK_gemm_xdl_int8) - + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index a9be3a7108..36bb5720d5 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1 +1,3 @@
-add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index de48093ac2..3821f8aaca 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS) add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() +endif() diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 573ad7239e..5b43de9725 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,9 +1,11 @@ -add_custom_target(example_permute) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_custom_target(example_permute) -add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) -add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) -add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) + add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) + add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) + add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) -add_dependencies(example_permute example_permute_1xHxW_fp16) -add_dependencies(example_permute example_permute_NxHxW_fp16) -add_dependencies(example_permute example_permute_HxWx4_fp16) + 
add_dependencies(example_permute example_permute_1xHxW_fp16) + add_dependencies(example_permute example_permute_NxHxW_fp16) + add_dependencies(example_permute example_permute_HxWx4_fp16) +endif() diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 1962b6013b..55464957ac 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -10,21 +10,19 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -# Conv perlayer quantization -add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) -# Conv perchannel quantization -add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - -# Conv + bias + relu perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - -# Conv + bias + relu perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - -# Conv + bias + tanh perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - -# Conv + bias + tanh perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) + if(DL_KERNELS) + # Conv perlayer quantization + add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) + # Conv perchannel quantization + add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) + # Conv + bias + relu perlayer quantization + 
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) + # Conv + bias + relu perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) + # Conv + bias + tanh perlayer quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) + # Conv + bias + tanh perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) + endif() endif() \ No newline at end of file diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index ae251e88d2..085d90a9e9 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -3,9 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) + endif() if(USE_BITINT_EXTENSION_INT4) 
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -14,5 +20,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) + endif() endif() diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt index e8c306ac58..bc2246a2bf 100644 --- a/example/42_groupnorm/CMakeLists.txt +++ b/example/42_groupnorm/CMakeLists.txt @@ -1,3 +1,5 @@ -add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) -add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) -add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) + add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) + add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) +endif() diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index c29f18f162..7e070f5357 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,6 @@ -add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) -add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +endif() 
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) +endif() diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 0e0091a986..877a820315 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,2 +1,4 @@ -add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) -add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) + add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +endif() diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index bfe057e8da..cf7d81f895 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,2 +1,6 @@ -add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) -add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) + endif() + add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) +endif() diff --git a/example/48_pool3d_fwd/CMakeLists.txt b/example/48_pool3d_fwd/CMakeLists.txt index 5d58d6a0b1..f821f2c97b 100644 --- a/example/48_pool3d_fwd/CMakeLists.txt +++ b/example/48_pool3d_fwd/CMakeLists.txt @@ -1,2 +1,3 @@ -add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) - +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) 
+endif() diff --git a/example/49_maxpool2d_bwd/CMakeLists.txt b/example/49_maxpool2d_bwd/CMakeLists.txt index b29cf9ccbc..fe98b7c997 100644 --- a/example/49_maxpool2d_bwd/CMakeLists.txt +++ b/example/49_maxpool2d_bwd/CMakeLists.txt @@ -1,3 +1,9 @@ -add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) -add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) -add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) +endif() diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt index 1b0020ebcf..eca410008f 100644 --- a/example/50_put_element/CMakeLists.txt +++ b/example/50_put_element/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_put_element_fp16 put_element_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_put_element_fp16 put_element_fp16.cpp) +endif() diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp index c3c8c0e5a7..a89358538e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef __bf16__ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector>>& @@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif 
+#ifdef __fp16__ void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( std::vector>>& @@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif +#ifdef __fp32__ void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( std::vector>>& @@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif +#ifdef __int8__ void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( std::vector>>& instances); - +#endif template > op_ptrs; - +#ifdef __fp32__ if constexpr(is_same_v && is_same_v && is_same_v) { @@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef __fp16__ + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef __bf16__ + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef __int8__ + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef __bf16__ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( std::vector>>& instances); - +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && Acc0BiasDataType::Size() == 1 && @@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef __bf16__ else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && Acc0BiasDataType::Size() == 1 && @@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif 
return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index 28ccf61a3a..e57a4206d5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef __fp16__ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( std::vector>>& instances); - +#endif +#ifdef __int8__ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( std::vector>>& instances); - +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && is_same_v) { @@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef __bf16__ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( std::vector>>& instances); +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef __bf16__ else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index 2ed8255f6d..da391feef3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -16,7 +16,7 @@ namespace ck 
{ namespace tensor_operation { namespace device { namespace instance { - +#ifdef __fp32__ // float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); - +#endif +#ifdef __fp64__ // double void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances); - +#endif // Contraction + Bilinear template > op_ptrs; - +#ifdef __fp32__ if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { @@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef __fp64__ // double void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances); - +#endif // Contraction + Scale template > op_ptrs; - +#ifdef __fp32__ if constexpr(is_same_v && is_same_v && is_same_v) { @@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef __fp16__ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef __fp32__ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( std::vector>>& instances); +#endif #ifdef __int8__ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( std::vector>>& instances); #endif +#ifdef __bf16__ // conv2d backward data void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef __fp16__ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef __fp32__ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); +#endif #ifdef __int8__ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); #endif +#ifdef DL_KERNELS +#ifdef 
__fp16__ // conv2d dl void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef __fp32__ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); +#endif #ifdef __int8__ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); #endif +#endif +#ifdef __bf16__ // conv3d backward data void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef __fp16__ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef __fp32__ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( std::vector>>& instances); +#endif #ifdef __int8__ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( std::vector && is_same_v && - is_same_v) +#ifdef __fp16__ + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef __bf16__ + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); } +#endif #ifdef __int8__ - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); } @@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#ifdef __fp16__ + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); +#endif } - else if 
constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef __bf16__ + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); } +#endif #ifdef __int8__ - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); +#endif } #endif } @@ -286,20 +314,23 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#ifdef __fp16__ + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef __bf16__ + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); } +#endif #ifdef __int8__ - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp index dd8c1987cf..af4c2aee66 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -18,11 +18,17 @@ namespace device { namespace instance { // conv2d forward +#ifdef __fp16__ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); +#endif +#ifdef __bf16__ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>>& instances); - -void 
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector>>& - instances); - +#endif +#ifdef __fp32__ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef __int8__ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); +#endif template && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs); } +#endif +#ifdef __bf16__ else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); } +#endif +#ifdef __int8__ else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); } +#endif } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp index a1c006cf62..b863243c1b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef __fp16__ namespace ck { namespace tensor_operation { namespace device { @@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -388,6 +389,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } +#endif +#ifdef __bf16__ else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -412,6 +415,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); } } +#endif #ifdef __int8__ else if constexpr(is_same_v && 
is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp index de21f325a6..4b11a9b262 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef __fp16__ namespace ck { namespace tensor_operation { namespace device { @@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory>>& @@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory>>&); @@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances( void add_device_normalization_rank_5_3_f16_instances( std::vector>>&); - +#endif +#ifdef __fp32__ // FP32 void add_device_normalization_rank_2_1_f32_instances( std::vector>>&); @@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances( void add_device_normalization_rank_5_3_f32_instances( std::vector>>&); - +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef __fp32__ + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { @@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory>>&); - +#endif +#ifdef __fp32__ // FP32 void add_device_pool2d_fwd_nhwc_f32_instances( std::vector< @@ -50,7 +51,7 @@ void add_device_pool2d_fwd_nhwc_f32_instances( void add_device_pool2d_fwd_nhwc_index_f32_instances( std::vector< std::unique_ptr>>&); - +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && 
is_same_v) { @@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef __fp32__ + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(OutputIndex && ReduceOpId == MaxOp) { @@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory>>&); - +#endif +#ifdef __fp32__ // FP32 void add_device_pool3d_fwd_ndhwc_f32_instances( std::vector< @@ -50,7 +51,7 @@ void add_device_pool3d_fwd_ndhwc_f32_instances( void add_device_pool3d_fwd_ndhwc_index_f32_instances( std::vector< std::unique_ptr>>&); - +#endif template > op_ptrs; - +#ifdef __fp16__ if constexpr(is_same_v && is_same_v && is_same_v) { @@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef __fp32__ + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(OutputIndex && ReduceOpId == MaxOp) { @@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); } } @@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); } } @@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); } } @@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); +#endif 
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); } } @@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); } } @@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); } } @@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); } } @@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); } } @@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_perchannel_quantization_int8_instances( std::vector) { +#ifdef DL_KERNELS 
add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs); } } @@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_perlayer_quantization_int8_instances( std::vector) { +#ifdef DL_KERNELS add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs); } } @@ -144,3 +148,4 @@ struct DeviceOperationInstanceFactory> op_ptrs; - +#ifdef __fp16__ if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { @@ -65,8 +65,10 @@ struct DeviceOperationInstanceFactory && std::is_same_v && - std::is_same_v) +#endif +#ifdef __fp32__ + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) { if constexpr(Rank == 3) { @@ -89,7 +91,7 @@ struct DeviceOperationInstanceFactory, std::tuple>; +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); TYPED_TEST(TestAvgPool2dFwd, Test_Pool) diff --git a/test/pool_fwd/test_avg_pool3d_fwd.cpp b/test/pool_fwd/test_avg_pool3d_fwd.cpp index 00cc3740fa..a99a3c698e 100644 --- a/test/pool_fwd/test_avg_pool3d_fwd.cpp +++ b/test/pool_fwd/test_avg_pool3d_fwd.cpp @@ -40,10 +40,12 @@ class TestAvgPool3dFwd : public ::testing::Test } } }; - +#ifdef __fp16__ using KernelTypes = ::testing::Types, std::tuple>; - +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes); TYPED_TEST(TestAvgPool3dFwd, Test_Pool) { 
diff --git a/test/pool_fwd/test_max_pool2d_fwd.cpp b/test/pool_fwd/test_max_pool2d_fwd.cpp index 1cf1314f46..c9050848d2 100644 --- a/test/pool_fwd/test_max_pool2d_fwd.cpp +++ b/test/pool_fwd/test_max_pool2d_fwd.cpp @@ -59,10 +59,12 @@ class TestMaxPool2dFwd : public ::testing::Test } } }; - +#ifdef __fp16__ using KernelTypes = ::testing::Types, std::tuple>; - +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); TYPED_TEST(TestMaxPool2dFwd, Test_Pool) { diff --git a/test/pool_fwd/test_max_pool3d_fwd.cpp b/test/pool_fwd/test_max_pool3d_fwd.cpp index 0b0de4d90e..7c0aada473 100644 --- a/test/pool_fwd/test_max_pool3d_fwd.cpp +++ b/test/pool_fwd/test_max_pool3d_fwd.cpp @@ -60,8 +60,12 @@ class TestMaxPool3dFwd : public ::testing::Test } }; +#ifdef __fp16__ using KernelTypes = - ::testing::Types, std::tuple>; + ::testing::Types, std::tuple>; +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes); TYPED_TEST(TestMaxPool3dFwd, Test_Pool) diff --git a/test/softmax/test_softmax_rank3.cpp b/test/softmax/test_softmax_rank3.cpp index 43ae11bf1f..f17672286f 100644 --- a/test/softmax/test_softmax_rank3.cpp +++ b/test/softmax/test_softmax_rank3.cpp @@ -10,8 +10,9 @@ template using I = ck::Number; - +#ifdef __fp16__ using F16 = ck::half_t; +#endif using F32 = float; template @@ -22,7 +23,9 @@ class TestSoftmax : public ck::TestSoftmax // clang-format off using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank +#ifdef __fp16__ std::tuple< F16, F32, F16, I<3>>, +#endif std::tuple< F32, F32, F32, I<3>> >; // clang-format on diff --git a/test/softmax/test_softmax_rank4.cpp b/test/softmax/test_softmax_rank4.cpp index 5cf96bbaa8..161aebf617 100644 --- a/test/softmax/test_softmax_rank4.cpp +++ b/test/softmax/test_softmax_rank4.cpp @@ -10,8 +10,9 @@ template using I = ck::Number; - +#ifdef __fp16__ using F16 = ck::half_t; +#endif using F32 = float; template @@ -22,7 
+23,9 @@ class TestSoftmax : public ck::TestSoftmax // clang-format off using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank +#ifdef __fp16__ std::tuple< F16, F32, F16, I<4>>, +#endif std::tuple< F32, F32, F32, I<4>> >; // clang-format on