From 1f4d13b2b5dc44844131884882c97161aedeaf6c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:42:17 -0700 Subject: [PATCH] Split the instances by architecture. (#1223) * parse examples inside the add_example_executable function * fix the example 64 cmake file * add xdl flag to the gemm_bias_softmax_gemm_permute example * add filtering of tests based on architecture type * enable test_grouped_gemm for gfx9 only * enable test_transpose only for gfx9 * only linnk test_transpose if it gets built * split the gemm instances by architectures * split gemm_bilinear,grouped_conv_bwd_weight instances by targets * split instances by architecture * split grouped_conv instances by architecture * fix clang format * fix the if-else logic in group_conv headers * small fix for grouped convolution instances * fix the grouped conv bwd weight dl instances * fix client examples * only enable client examples 3 and 4 on gfx9 * set the gfx9 macro * make sure the architecture macros are set by cmake * use separate set of xdl/wmma flags for host code * sinmplify the main cmake file * add conv_fwd_bf8 instance declaration [ROCm/composable_kernel commit: ae57e5938e7fdfd049055a855910f66054e04163] --- CMakeLists.txt | 16 + .../02_gemm_add_add_fastgelu/CMakeLists.txt | 34 +- .../03_gemm_layernorm/CMakeLists.txt | 10 +- client_example/04_contraction/CMakeLists.txt | 23 +- .../07_grouped_convnd_fwd/CMakeLists.txt | 10 +- .../08_fused_attention/CMakeLists.txt | 10 +- client_example/09_quantization/CMakeLists.txt | 30 +- .../15_convnd_bwd_data/CMakeLists.txt | 10 +- .../15_gemm_add_multiply/CMakeLists.txt | 7 +- .../17_grouped_gemm_fastgelu/CMakeLists.txt | 6 +- client_example/20_splitk_gemm/CMakeLists.txt | 2 +- .../21_grouped_gemm_bias/CMakeLists.txt | 6 +- client_example/22_grouped_gemm/CMakeLists.txt | 18 +- .../24_grouped_conv_activation/CMakeLists.txt | 2 + client_example/25_wrapper/CMakeLists.txt | 4 +- client_example/CMakeLists.txt | 5 +- example/01_gemm/CMakeLists.txt | 11 +- example/02_gemm_bilinear/CMakeLists.txt | 23 +- example/03_gemm_bias_relu/CMakeLists.txt | 9 +- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 35 +- example/09_convnd_fwd/CMakeLists.txt | 23 +- .../CMakeLists.txt | 34 +- example/14_gemm_quantization/CMakeLists.txt | 13 +- .../CMakeLists.txt | 67 +- example/17_convnd_bwd_data/CMakeLists.txt | 15 +- .../20_grouped_conv_bwd_weight/CMakeLists.txt | 34 +- example/21_gemm_layernorm/CMakeLists.txt | 16 +- example/26_contraction/CMakeLists.txt | 32 +- .../CMakeLists.txt | 5 +- .../CMakeLists.txt | 51 +- example/31_batched_gemm_gemm/CMakeLists.txt | 20 +- .../CMakeLists.txt | 14 +- example/35_splitK_gemm/CMakeLists.txt | 43 +- .../CMakeLists.txt | 31 +- .../40_conv2d_fwd_quantization/CMakeLists.txt | 39 +- .../41_grouped_conv_conv_fwd/CMakeLists.txt | 20 +- .../CMakeLists.txt | 9 +- ...=> gemm_bias_softmax_gemm_permute_xdl.cpp} | 0 example/52_im2col_col2im/CMakeLists.txt | 18 +- example/60_gemm_multi_ABD/CMakeLists.txt | 9 +- .../61_contraction_multi_ABD/CMakeLists.txt | 9 +- example/62_convnd_activ/CMakeLists.txt | 19 +- example/64_fpAintB_gemm/CMakeLists.txt | 8 +- example/CMakeLists.txt | 36 + include/ck/ck.hpp | 14 +- .../tensor_operation_instance/gpu/gemm.hpp | 583 ++------- .../gpu/gemm_bilinear.hpp | 8 +- .../tensor_operation_instance/gpu/gemm_dl.inc | 167 +++ .../gpu/gemm_wmma.inc | 34 + .../gpu/gemm_xdl.inc | 238 ++++ .../gpu/grouped_convolution_backward_data.hpp | 673 +++------- ...grouped_convolution_backward_data_wmma.inc | 243 ++++ .../grouped_convolution_backward_data_xdl.inc | 216 ++++ .../grouped_convolution_backward_weight.hpp | 889 ++++--------- ...grouped_convolution_backward_weight_dl.inc | 243 ++++ ...ouped_convolution_backward_weight_wmma.inc | 114 ++ ...rouped_convolution_backward_weight_xdl.inc | 228 ++++ .../gpu/grouped_convolution_forward.hpp | 1121 +++-------------- .../gpu/grouped_convolution_forward_dl.inc | 73 ++ .../gpu/grouped_convolution_forward_wmma.inc | 480 +++++++ .../gpu/grouped_convolution_forward_xdl.inc | 357 ++++++ .../gpu/CMakeLists.txt | 35 + .../gpu/batched_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../batched_gemm_bias_permute/CMakeLists.txt | 1 + .../gpu/batched_gemm_gemm/CMakeLists.txt | 1 + .../gpu/batched_gemm_reduce/CMakeLists.txt | 1 + .../batched_gemm_softmax_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/contraction_bilinear/CMakeLists.txt | 1 + .../gpu/contraction_scale/CMakeLists.txt | 1 + .../gpu/conv1d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_fwd/CMakeLists.txt | 1 + .../gpu/conv2d_fwd_bias_relu/CMakeLists.txt | 1 + .../conv2d_fwd_bias_relu_add/CMakeLists.txt | 1 + .../gpu/conv3d_bwd_data/CMakeLists.txt | 1 + .../gpu/gemm_add/CMakeLists.txt | 1 + .../gpu/gemm_add_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_multiply/CMakeLists.txt | 1 + .../gpu/gemm_add_relu/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/gemm_add_silu/CMakeLists.txt | 1 + .../gpu/gemm_bias_add_reduce/CMakeLists.txt | 1 + .../gpu/gemm_bilinear/CMakeLists.txt | 1 + .../gpu/gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_multiply_add/CMakeLists.txt | 1 + .../gpu/gemm_reduce/CMakeLists.txt | 1 + .../gpu/gemm_splitk/CMakeLists.txt | 1 + .../gpu/gemm_streamk/CMakeLists.txt | 1 + .../grouped_conv1d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_data/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + .../grouped_conv3d_bwd_data/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_fwd_scale/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/grouped_gemm/CMakeLists.txt | 1 + .../gpu/grouped_gemm_bias/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 1 + .../gpu/quantization/CMakeLists.txt | 1 + profiler/src/CMakeLists.txt | 171 +-- test/CMakeLists.txt | 25 +- test/batched_gemm/CMakeLists.txt | 11 +- ...hed_gemm.cpp => test_batched_gemm_xdl.cpp} | 0 test/batched_gemm_gemm/CMakeLists.txt | 19 +- ...pp => test_batched_gemm_gemm_fp16_xdl.cpp} | 0 test/batched_gemm_reduce/CMakeLists.txt | 13 +- ...6.cpp => batched_gemm_reduce_fp16_xdl.cpp} | 0 test/batched_gemm_softmax_gemm/CMakeLists.txt | 19 +- ...st_batched_gemm_softmax_gemm_fp16_xdl.cpp} | 0 .../CMakeLists.txt | 50 +- ...mm_bias_softmax_gemm_permute_bf16_xdl.cpp} | 0 ...mm_bias_softmax_gemm_permute_fp16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_bf16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_fp16_xdl.cpp} | 0 test/contraction/CMakeLists.txt | 21 +- ...cpp => test_contraction_interface_xdl.cpp} | 0 ...ntraction.cpp => test_contraction_xdl.cpp} | 0 test/convnd_bwd_data/CMakeLists.txt | 11 +- ...d_bwd_data.cpp => convnd_bwd_data_xdl.cpp} | 0 test/convnd_fwd/CMakeLists.txt | 11 +- .../{convnd_fwd.cpp => convnd_fwd_xdl.cpp} | 0 test/gemm_add/CMakeLists.txt | 24 +- ...elu.cpp => test_gemm_add_fastgelu_xdl.cpp} | 2 +- ...dd_relu.cpp => test_gemm_add_relu_xdl.cpp} | 2 +- ...dd_silu.cpp => test_gemm_add_silu_xdl.cpp} | 2 +- ...est_gemm_add.hpp => test_gemm_add_xdl.hpp} | 0 test/gemm_layernorm/CMakeLists.txt | 19 +- ..._gemm_add_relu_add_layernorm_fp16_xdl.cpp} | 0 test/gemm_reduce/CMakeLists.txt | 2 +- ...duce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} | 0 test/gemm_split_k/CMakeLists.txt | 9 +- ...mm_splitk.cpp => test_gemm_splitk_xdl.cpp} | 0 test/grouped_convnd_bwd_data/CMakeLists.txt | 31 +- ...test_grouped_convnd_bwd_data_xdl_wmma.cpp} | 0 test/grouped_convnd_bwd_weight/CMakeLists.txt | 26 +- ...st_grouped_convnd_bwd_weight_xdl_wmma.cpp} | 0 test/grouped_convnd_fwd/CMakeLists.txt | 16 +- ...ti_d_interface_compatibility_xdl_wmma.cpp} | 0 ...p => test_grouped_convnd_fwd_xdl_wmma.cpp} | 0 test/grouped_gemm/CMakeLists.txt | 27 +- ...pp => test_grouped_gemm_interface_xdl.cpp} | 0 ...k.cpp => test_grouped_gemm_splitk_xdl.cpp} | 0 test/normalization_bwd_data/CMakeLists.txt | 13 +- .../CMakeLists.txt | 13 +- test/permute_scale/CMakeLists.txt | 6 +- test/transpose/CMakeLists.txt | 13 +- ...t_transpose.cpp => test_transpose_xdl.cpp} | 0 test/wrapper/CMakeLists.txt | 6 +- ...per_gemm.cpp => test_wrapper_gemm_xdl.cpp} | 0 160 files changed, 3770 insertions(+), 3392 deletions(-) rename example/47_gemm_bias_softmax_gemm_permute/{gemm_bias_softmax_gemm_permute.cpp => gemm_bias_softmax_gemm_permute_xdl.cpp} (100%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc rename test/batched_gemm/{test_batched_gemm.cpp => test_batched_gemm_xdl.cpp} (100%) rename test/batched_gemm_gemm/{test_batched_gemm_gemm_fp16.cpp => test_batched_gemm_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_reduce/{batched_gemm_reduce_fp16.cpp => batched_gemm_reduce_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm/{test_batched_gemm_softmax_gemm_fp16.cpp => test_batched_gemm_softmax_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp => test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp => test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_bf16.cpp => test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_fp16.cpp => test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/contraction/{test_contraction_interface.cpp => test_contraction_interface_xdl.cpp} (100%) rename test/contraction/{test_contraction.cpp => test_contraction_xdl.cpp} (100%) rename test/convnd_bwd_data/{convnd_bwd_data.cpp => convnd_bwd_data_xdl.cpp} (100%) rename test/convnd_fwd/{convnd_fwd.cpp => convnd_fwd_xdl.cpp} (100%) rename test/gemm_add/{test_gemm_add_fastgelu.cpp => test_gemm_add_fastgelu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_relu.cpp => test_gemm_add_relu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_silu.cpp => test_gemm_add_silu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add.hpp => test_gemm_add_xdl.hpp} (100%) rename test/gemm_layernorm/{test_gemm_add_relu_add_layernorm_fp16.cpp => test_gemm_add_relu_add_layernorm_fp16_xdl.cpp} (100%) rename test/gemm_reduce/{gemm_reduce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} (100%) rename test/gemm_split_k/{test_gemm_splitk.cpp => test_gemm_splitk_xdl.cpp} (100%) rename test/grouped_convnd_bwd_data/{test_grouped_convnd_bwd_data.cpp => test_grouped_convnd_bwd_data_xdl_wmma.cpp} (100%) rename test/grouped_convnd_bwd_weight/{test_grouped_convnd_bwd_weight.cpp => test_grouped_convnd_bwd_weight_xdl_wmma.cpp} (100%) rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp => test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp} (100%) rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd.cpp => test_grouped_convnd_fwd_xdl_wmma.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_interface.cpp => test_grouped_gemm_interface_xdl.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_splitk.cpp => test_grouped_gemm_splitk_xdl.cpp} (100%) rename test/transpose/{test_transpose.cpp => test_transpose_xdl.cpp} (100%) rename test/wrapper/{test_wrapper_gemm.cpp => test_wrapper_gemm_xdl.cpp} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index bdeba33eac..3c77f520ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,22 @@ if(GPU_TARGETS) else() message("Building CK for the following targets: ${AMDGPU_TARGETS}") endif() + +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index 772b699955..4ba86026b2 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,27 +1,29 @@ -add_custom_target(client_gemm_fastgelu_examples) +if(GPU_TARGETS MATCHES "gfx9") + add_custom_target(client_gemm_fastgelu_examples) -add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) + target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu + add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) -add_custom_target(client_gemm_fastgelu_generic_examples) + add_custom_target(client_gemm_fastgelu_generic_examples) -add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) -target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) + target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) +endif() diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index 94b4576f64..8fedc84635 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) + target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) -add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) + add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) + target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +endif() diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index cd4a95124c..13c0375846 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,16 @@ -add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) + target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) + target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) + target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) - -add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) + target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) + target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 40f1bba064..710eca9f49 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) + target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) -add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) + add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) + target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 9472be07b5..4bcde367dc 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_fused_attention fused_attention.cpp) + target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_fused_attention_bias fused_attention_bias.cpp) + target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 65ad642ce2..d2d3a427e8 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) + add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_gemm_quantization gemm_quantization.cpp) + target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) endif() diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index f35cd82d79..8fc62bc2bb 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) -add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) + add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index 4b4d762003..a683f78571 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ - -add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) + target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index fd315afbd2..39bef71814 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) -target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) + target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a3dc853767..05fcaa8103 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) endif() diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt index 92e31495c2..a09921e50a 100644 --- a/client_example/21_grouped_gemm_bias/CMakeLists.txt +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 0c3cb956f0..1e1c39681e 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -1,11 +1,13 @@ -add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 074dcd9b97..e79dee9f7d 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -1,3 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") # Fwd scaleadd scaleadd relu add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) @@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke add_executable(client_grouped_convnd_bwd_data_scale_fp16 grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index fdfc1d8d2e..b1e9d20bfd 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx9") add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 753f5e5ae5..3aa9efa315 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,10 @@ else() endif() endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +if(GPU_TARGETS MATCHES "gfx9") + find_package(composable_kernel COMPONENTS device_contraction_operations) +endif() find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2fa8e77462..39e3f2a2bd 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) -endif() add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) @@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -# FIXME: re-enable this example as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) @@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index d82c42d5a9..2c20b96eee 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,20 +1,3 @@ -list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) - add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) -endif() -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 2f5cba924d..35c54abac0 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 33ac1e7e77..ab19f819e8 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,29 +1,20 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) +add_custom_target(example_gemm_add_add_fastgelu_xdl) +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - set(target 1) - endif() -endforeach() - -set(gpu_list "") +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 195f1857ed..61e9a43c3a 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,19 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) - add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 222a3b7c0b..ef8bea1850 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,25 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) +add_custom_target(example_convnd_fwd_reduce_xdl) +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 9793e8b8a0..8703fa3ed7 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,12 +1,3 @@ -# dlops add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -# xdlops -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) - add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 5955e1d6cb..1e12c16f30 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,48 +1,41 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) +add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - add_example_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) +add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 7c6d10d8a0..39f9fb8ec0 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,14 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - endif() - set(target 1) - endif() -endforeach() +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +endif() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index c28fca6fa2..497ea19e11 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,29 +1,15 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) +add_custom_target(example_grouped_conv_bwd_weight) +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) +add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - set(target 1) - endif() +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) - set(target 1) - endif() -endforeach() - -add_custom_target(example_grouped_conv_bwd_weight_dl) +add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) -add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index e231bc619b..2eb7052e1e 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,12 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9c..f3d30cea2a 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear) # FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) # FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) # FP16 add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) # BF16 add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) +add_example_dependencies(example_contraction example_contraction_scale) +add_example_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index f343cc1910..ac54aebdc2 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,2 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 3a8c2ef52f..7acb1a1907 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,40 +1,23 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) +add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() # USE_BITINT_EXTENSION_INT4 - - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 93f16c945f..8b648a7f73 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index c6cca7b586..519f539106 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,11 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) - add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) - add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) - add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) +add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) +add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) +add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 5277b32f63..9a62d85ace 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,32 +1,23 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_splitK_gemm_xdl) +add_custom_target(example_splitK_gemm_xdl) +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) +add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) - add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) +add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) - add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) +add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) - add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - endif() - - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 1ae179e950..72e6959649 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,27 +1,10 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) +add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) - set(target 1) - endif() -endforeach() - -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - - add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) - - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 2d804cafe9..991c1e464b 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,24 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - # Conv perlayer quantization - add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) - # Conv perchannel quantization - add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - # Conv + bias + relu perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - # Conv + bias + relu perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - # Conv + bias + tanh perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - # Conv + bias + tanh perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +# Conv perchannel quantization +add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index ae251e88d2..8ab56b21a6 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index 14432f6e23..df1956ca62 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp similarity index 100% rename from example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp rename to example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index 4dc6c8b4e0..63ee1d4312 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,15 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_im2col_col2im) +add_custom_target(example_im2col_col2im) - add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) - add_example_dependencies(example_im2col_col2im example_image_to_column_f32) +add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) +add_example_dependencies(example_im2col_col2im example_image_to_column_f32) - add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) - add_example_dependencies(example_im2col_col2im example_column_to_image_f32) - - set(target 1) - endif() -endforeach() +add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) +add_example_dependencies(example_im2col_col2im example_column_to_image_f32) diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index 57bc0b33ef..d3974897fe 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt index 42500b64e6..1b8bd4cad2 100644 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 6eaddd3ff7..5a35f9b608 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -2,16 +2,9 @@ add_subdirectory(binary) add_subdirectory(multi_AB) add_subdirectory(unary) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl) - # ScaleAdd ScaleAdd Relu - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_xdl) +# ScaleAdd ScaleAdd Relu +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt index 89cc2d7f62..b3c77b3bd7 100644 --- a/example/64_fpAintB_gemm/CMakeLists.txt +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_fpAintB_gemm_wmma) - add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) - add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) -endif() +add_custom_target(example_fpAintB_gemm_wmma) +add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) +add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c19ba93b69..5465adb779 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,12 @@ include_directories(BEFORE add_custom_target(examples) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(FILE_NAME) + add_dependencies(EXAMPLE_NAME FILE_NAME) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) @@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index c93d1d0639..0bda8b7590 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -45,6 +45,10 @@ #endif // define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif @@ -62,8 +66,7 @@ // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -75,8 +78,7 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ - defined(__gfx94__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -89,7 +91,7 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_MFMA #endif @@ -120,7 +122,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index ee9d977096..50c18fc22e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -12,398 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "gemm_dl.inc" +#endif +#ifdef CK_USE_WMMA +#include "gemm_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS) -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS) -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS) -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP64 -void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP8 -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif - -void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } +#ifdef CK_ENABLE_FP16 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + } + } #endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); + } + } +#endif +#endif // DL_KERNELS + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif +#endif + +#ifdef CK_USE_XDL + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( op_ptrs); @@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( op_ptrs); @@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( op_ptrs); @@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( op_ptrs); @@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif @@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); -#endif } } #endif @@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs); } } +#endif #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 1a518a5302..6ee88bd855 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 +#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA) void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector> op_ptrs; -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc new file mode 100644 index 0000000000..44a11f6284 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if defined(CK_ENABLE_FP16) +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_FP32) +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_INT8) +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc new file mode 100644 index 0000000000..c97298c258 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc new file mode 100644 index 0000000000..82a1dc425a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP64 +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP8 +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 09885ccd90..9a70a47274 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -10,439 +10,18 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_data_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_data_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -// conv2d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); - -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -// conv3d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - std::vector>>& instances); -#endif template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 2) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } } - else if constexpr(NumDimSpatial == 3) + if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( op_ptrs); @@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( op_ptrs); } #endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( op_ptrs); @@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory< #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc new file mode 100644 index 0000000000..fb2407bcc3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc new file mode 100644 index 0000000000..7ad0218410 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); +#endif + +// conv3d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index b8ca2c5fac..dc56b8f4b2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,565 +12,20 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_backward_weight_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_weight_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_weight_wmma.inc" +#endif namespace ck { namespace tensor_operation { namespace device { namespace instance { -// xdl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef DL_KERNELS -// dl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#endif - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(NumDimSpatial == 1) { if constexpr(is_same_v && is_same_v && @@ -621,10 +77,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -632,10 +85,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -644,19 +94,14 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - op_ptrs); } #endif } if constexpr(is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -683,6 +128,174 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif // DL_KERNELS +#ifdef CK_USE_XDL + if constexpr(NumDimSpatial == 1) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } #endif } } @@ -696,10 +309,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); } @@ -709,10 +318,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( op_ptrs); } @@ -723,10 +328,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); } @@ -740,10 +341,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); } @@ -753,10 +350,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); } @@ -767,10 +360,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( op_ptrs); } @@ -787,10 +376,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); } @@ -800,16 +385,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -818,13 +395,70 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } +#endif + } + } +#endif +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && @@ -842,50 +476,17 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( op_ptrs); } #endif -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v && @@ -897,18 +498,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - op_ptrs); - } #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc new file mode 100644 index 0000000000..59190a13e5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc new file mode 100644 index 0000000000..315547ca56 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv3d backward weight +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc new file mode 100644 index 0000000000..5562d236e6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index b9712542a8..24a5f9a5cb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -12,908 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_forward_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_BF16 -// grouped conv1d forward, GNWC/GKXC/GNWK -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -// grouped conv2d forward, NHWGC/GKYXC/NHWGK -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); - -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - template > op_ptrs; +#ifdef DL_KERNELS + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } +#endif + } + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } +#endif + } +#endif // DL_KERNELS + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && is_same_v) { @@ -1000,35 +154,13 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); } #endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1037,23 +169,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); - } -#endif } if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { - #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1061,15 +181,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1077,15 +188,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1093,16 +195,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); - } #endif } @@ -1121,12 +213,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1142,11 +228,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -1188,12 +269,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1209,6 +284,99 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); + } +#endif + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( @@ -1217,6 +385,7 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc new file mode 100644 index 0000000000..0ea24d0929 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc new file mode 100644 index 0000000000..942674ef99 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +// grouped conv1d forward, GNWC/GKXC/GNWK +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0a12e1c49e..c035e7e564 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build XDL instances if gfx9 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_library(${INSTANCE_NAME} OBJECT ${ARGN}) @@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") + set(add_inst 0) + endif() if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 69b6ddc754..1227a77a38 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(BATCHED_GEMM_INSTANCES) list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index d0e9b265af..5c8470f7cb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index cd9c95c066..8082a8c275 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 865a31e79a..2aa607429d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 28226fabac..51bbdf1d7c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 6244477e16..e43eb07fb6 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3fd4e03449..f1fb0646e4 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4e..a28c6717dd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3f..b91de832e4 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt index 75a3670761..796a9b2402 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv1d_bwd_data_instance device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 49dfc01fd9..2da5155117 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_BWD_DATA_INSTANCES) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index ba0ca32517..04b313d075 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 670cd94fc9..4304d8996c 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index 68d5f582fd..40a6b1ff09 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_add_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt index db92208fd7..ec4a8a2864 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv3d_bwd_data_instance device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index fe85bb7ead..298da1fbef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index bbf81a5fa2..04ae90bc5b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 63b4a00c99..45d6abce01 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index eb9345cbad..d859078ca9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_multiply_instance device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 969361de9a..043bdab001 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 97693a2566..b9aeb6a6db 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index c10d4773a7..e6ca64cdc1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_silu_instance device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt index ccada3a85e..f29943d93b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 426edeed74..61892e708c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 17d27ab150..2f45401ec6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 6cbd7528e7..aba9806a74 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_MULTIPLY_ADD_INSTANCES) list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 2b2cf8c774..7ee3efe7f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 059b6a720f..dac86d7707 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_SPLITK_INSTANCES) list(APPEND GEMM_SPLITK_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index 8dd0112a6b..c854b16eeb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index cfd829f87e..ab4313d89e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index f51a484bb5..ca4ea515bb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 93d5bd7422..ad430340ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library( device_grouped_conv2d_bwd_data_instance xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 8a896b06c7..340ddfb3f0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 2715a8cf21..1d3c3747d3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_instance #xdl # GNHWC, GKYXC, GNHWK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 836e671bf2..29fa8fa3c5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index e1cb975291..ae6dcb9880 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index b7901a2815..fa48f0edcc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 968e8dea2f..8b89dcf7ec 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 3825b92af4..972fb54031 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 49706588d6..436c37fd58 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 45d270d554..f36d55d367 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 08fb23afc9..1076249447 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index ae89caaeef..1be1db7d1d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index de7537af47..2625e6cbe8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index ef8a440c1a..167dfa9a6f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 648f2146cb..8e9693e691 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index ac22543bef..1ee3d0add4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index c22a6e9e96..5d50902be8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 11ae285167..cb6ffbec6c 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -2,19 +2,6 @@ set(PROFILER_SOURCES profiler.cpp profile_gemm.cpp - profile_gemm_splitk.cpp - profile_gemm_bias_add_reduce.cpp - profile_gemm_add_multiply.cpp - profile_gemm_multiply_add.cpp - profile_gemm_reduce.cpp - profile_batched_gemm.cpp - profile_batched_gemm_reduce.cpp - profile_conv_fwd.cpp - profile_conv_fwd_bias_relu.cpp - profile_conv_fwd_bias_relu_add.cpp - profile_conv_bwd_data.cpp - profile_grouped_conv_fwd.cpp - profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp @@ -29,36 +16,57 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp ) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) + list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + +endif() + +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +endif() + if(DL_KERNELS) list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) -endif() - -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -68,25 +76,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) @@ -96,39 +85,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() - +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +endif() if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0f90256c0..720ab468ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) @@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 9482821b68..759cf3da67 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) +add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/batched_gemm/test_batched_gemm.cpp b/test/batched_gemm/test_batched_gemm_xdl.cpp similarity index 100% rename from test/batched_gemm/test_batched_gemm.cpp rename to test/batched_gemm/test_batched_gemm_xdl.cpp diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 03f1d3a4eb..2b3288ef9d 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_gemm) - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_gemm) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) +endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp rename to test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 32c6ee85d1..c5868e4d7a 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,11 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - set(target 1) - endif() +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) endif() -endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp rename to test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c011a6a3c5..c042d7e000 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm) - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) +endif() diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp rename to test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3164863eef..2e09073540 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,29 +1,21 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm_permute) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - endif() - - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +add_custom_target(test_batched_gemm_softmax_gemm_permute) +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) +endif() diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index a86e72fddb..3ba0d82f0e 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,13 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) - add_gtest_executable(test_contraction test_contraction.cpp) - target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) - target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - set(target 1) - endif() +if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) + add_gtest_executable(test_contraction test_contraction_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) endif() -endforeach() + add_gtest_executable(test_contraction_interface test_contraction_interface_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + endif() +endif() diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface_xdl.cpp similarity index 100% rename from test/contraction/test_contraction_interface.cpp rename to test/contraction/test_contraction_interface_xdl.cpp diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction_xdl.cpp similarity index 100% rename from test/contraction/test_contraction.cpp rename to test/contraction/test_contraction_xdl.cpp diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index f734b46f53..e68a9b243c 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +add_gtest_executable(test_convnd_bwd_data convnd_bwd_data_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp similarity index 100% rename from test/convnd_bwd_data/convnd_bwd_data.cpp rename to test/convnd_bwd_data/convnd_bwd_data_xdl.cpp diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 745aceffc9..ba6d16a0d5 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) +add_gtest_executable(test_convnd_fwd convnd_fwd_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) - set(target 1) - endif() -endforeach() +endif() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp similarity index 100% rename from test/convnd_fwd/convnd_fwd.cpp rename to test/convnd_fwd/convnd_fwd_xdl.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 7df3f90abc..ab4c781847 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,11 +1,19 @@ -add_gtest_executable(test_gemm_add test_gemm_add.hpp) -target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) -target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) -target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) -target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_fastgelu.cpp rename to test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index c1c55140a0..1b12ab7528 100644 --- a/test/gemm_add/test_gemm_add_fastgelu.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddFastgelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_relu.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_relu.cpp rename to test/gemm_add/test_gemm_add_relu_xdl.cpp index ba6aab36bd..e8b769b1cb 100644 --- a/test/gemm_add/test_gemm_add_relu.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddRelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_silu.cpp b/test/gemm_add/test_gemm_add_silu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_silu.cpp rename to test/gemm_add/test_gemm_add_silu_xdl.cpp index d4dd6fa38b..75fa59a8e7 100644 --- a/test/gemm_add/test_gemm_add_silu.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddSilu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add.hpp b/test/gemm_add/test_gemm_add_xdl.hpp similarity index 100% rename from test/gemm_add/test_gemm_add.hpp rename to test/gemm_add/test_gemm_add_xdl.hpp diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index bfc4404bd8..d1102a561a 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_gemm_layernorm) - add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) - set(target 1) - endif() - endif() -endforeach() +add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_gemm_layernorm) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) +endif() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp similarity index 100% rename from test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp rename to test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 42a53c3048..121ecde609 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) endif() \ No newline at end of file diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/gemm_reduce/gemm_reduce_fp16.cpp rename to test/gemm_reduce/gemm_reduce_fp16_xdl.cpp diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index caf30fca59..4b66dddef9 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) +add_gtest_executable(test_gemm_splitk test_gemm_splitk_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) - set(target 1) endif() -endforeach() diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk_xdl.cpp similarity index 100% rename from test/gemm_split_k/test_gemm_splitk.cpp rename to test/gemm_split_k/test_gemm_splitk_xdl.cpp diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 305c568ee9..3507989bae 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,19 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index d7d6f8a3d6..34cdc63cd9 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,20 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight_xdl_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1ce878d5ca..4f245d63cd 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,8 +1,14 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) -target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +endif() -add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 8c57b667e2..f47685cf91 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,14 +1,13 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_grouped_gemm) - add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) - add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) - - add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface) - set(target 1) - endif() -endforeach() +add_custom_target(test_grouped_gemm) + +add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +endif() + +add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_interface) +endif() diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_interface.cpp rename to test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_splitk.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt index 1b6decfed7..65f33da74d 100644 --- a/test/normalization_bwd_data/CMakeLists.txt +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_data) add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt index f3579aad08..afb78dc58e 100644 --- a/test/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_gamma_beta) add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt index be6aaf94aa..d63cb79910 100644 --- a/test/permute_scale/CMakeLists.txt +++ b/test/permute_scale/CMakeLists.txt @@ -1,6 +1,4 @@ add_custom_target(test_permute) add_gtest_executable(test_permute_scale test_permute_scale.cpp) -if(result EQUAL 0) - target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) - add_dependencies(test_permute test_permute_scale) -endif() +target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) +add_dependencies(test_permute test_permute_scale) diff --git a/test/transpose/CMakeLists.txt b/test/transpose/CMakeLists.txt index 530cc9d72d..fb9379bea9 100644 --- a/test/transpose/CMakeLists.txt +++ b/test/transpose/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_transpose test_transpose.cpp) - target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) - set(target 1) - endif() -endforeach() +add_gtest_executable(test_transpose test_transpose_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) +endif() diff --git a/test/transpose/test_transpose.cpp b/test/transpose/test_transpose_xdl.cpp similarity index 100% rename from test/transpose/test_transpose.cpp rename to test/transpose/test_transpose_xdl.cpp diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 383707828c..1eb6c35db2 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -12,10 +12,8 @@ add_dependencies(test_wrapper test_wrapper_copy) add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) target_link_libraries(test_wrapper_partition PRIVATE utility) add_dependencies(test_wrapper test_wrapper_partition) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) +add_gtest_executable(test_wrapper_gemm test_wrapper_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_wrapper_gemm PRIVATE utility) add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm_xdl.cpp similarity index 100% rename from test/wrapper/test_wrapper_gemm.cpp rename to test/wrapper/test_wrapper_gemm_xdl.cpp