diff --git a/CMakeLists.txt b/CMakeLists.txt index d22bb94513..f2a06a091f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,6 +145,22 @@ if(GPU_TARGETS) else() message("Building CK for the following targets: ${AMDGPU_TARGETS}") endif() + +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index 772b699955..4ba86026b2 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,27 +1,29 @@ -add_custom_target(client_gemm_fastgelu_examples) +if(GPU_TARGETS MATCHES "gfx9") + add_custom_target(client_gemm_fastgelu_examples) -add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE 
composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) + target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu + add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) -add_custom_target(client_gemm_fastgelu_generic_examples) + add_custom_target(client_gemm_fastgelu_generic_examples) -add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) -target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) + target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) +endif() diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt 
b/client_example/03_gemm_layernorm/CMakeLists.txt index 94b4576f64..8fedc84635 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) + target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) -add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) + add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) + target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +endif() diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index cd4a95124c..13c0375846 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,16 @@ -add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) + target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations 
composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) + target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) + target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) - -add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) + target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations 
composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) + target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 40f1bba064..710eca9f49 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) + target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) -add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) + add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) + target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 9472be07b5..4bcde367dc 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_fused_attention fused_attention.cpp) + 
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_fused_attention_bias fused_attention_bias.cpp) + target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 65ad642ce2..d2d3a427e8 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) + add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + 
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perchannel_quantization 
conv2d_fwd_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_gemm_quantization gemm_quantization.cpp) + target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) endif() diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index f35cd82d79..8fc62bc2bb 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) -add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) + add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE 
composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index 4b4d762003..a683f78571 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ - -add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) + target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index fd315afbd2..39bef71814 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) -target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) + target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a3dc853767..05fcaa8103 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ 
b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) endif() diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt index 92e31495c2..a09921e50a 100644 --- a/client_example/21_grouped_gemm_bias/CMakeLists.txt +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 0c3cb956f0..1e1c39681e 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -1,11 +1,13 @@ -add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE 
composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 074dcd9b97..e79dee9f7d 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -1,3 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") # Fwd scaleadd scaleadd relu add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) @@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke add_executable(client_grouped_convnd_bwd_data_scale_fp16 grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index fdfc1d8d2e..b1e9d20bfd 
100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx9") add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 753f5e5ae5..3aa9efa315 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,10 @@ else() endif() endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +if(GPU_TARGETS MATCHES "gfx9") + find_package(composable_kernel COMPONENTS device_contraction_operations) +endif() find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2fa8e77462..39e3f2a2bd 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 
gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) -endif() add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) @@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -# FIXME: re-enable this example as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) @@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index d82c42d5a9..2c20b96eee 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,20 +1,3 @@ -list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) - add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) -endif() -if(GPU_TARGETS MATCHES 
"gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 2f5cba924d..35c54abac0 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 33ac1e7e77..ab19f819e8 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,29 +1,20 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) +add_custom_target(example_gemm_add_add_fastgelu_xdl) +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 
gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - set(target 1) - endif() -endforeach() - -set(gpu_list "") +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) diff 
--git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 195f1857ed..61e9a43c3a 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,19 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) - add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 222a3b7c0b..ef8bea1850 100644 --- 
a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,25 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) +add_custom_target(example_convnd_fwd_reduce_xdl) +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - 
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 9793e8b8a0..8703fa3ed7 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,12 +1,3 @@ -# dlops add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -# xdlops -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) - add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 84040fcf5c..550dafb066 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) -add_example_dependencies(example_grouped_gemm_xdl 
example_grouped_gemm_xdl_fixed_nk_fp8) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 2c1feafce3..1a2bcfb33e 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; -using EDataType = F32; +using EDataType = F16; using ALayout = Row; using BLayout = Col; @@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, 
S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on struct ProblemSize final @@ -298,9 +298,9 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(256 + 256 * i); - problem_size.Ns.push_back(256); - problem_size.Ks.push_back(128); + problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp similarity index 99% rename from example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 9fd63cba77..0a63a29843 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F16; using BDataType = F8; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using DsDataType = ck::Tuple<>; using EDataType = F16; @@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on struct ProblemSize final diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 5955e1d6cb..1e12c16f30 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,48 +1,41 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - 
add_custom_target(example_gemm_add_add_mean_meansquare_xdl) +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) 
+add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - add_example_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) +add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) - if(USE_BITINT_EXTENSION_INT4) - 
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 7c6d10d8a0..39f9fb8ec0 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,14 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - endif() - set(target 1) - endif() -endforeach() +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +endif() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index c28fca6fa2..497ea19e11 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,29 +1,15 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - 
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) +add_custom_target(example_grouped_conv_bwd_weight) +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) +add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - set(target 1) - endif() +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) - set(target 1) - endif() -endforeach() - -add_custom_target(example_grouped_conv_bwd_weight_dl) +add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) 
-add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index e231bc619b..2eb7052e1e 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,12 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9c..f3d30cea2a 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear) # FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) 
-add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) # FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) 
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) # FP16 add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) # BF16 add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) 
-add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) +add_example_dependencies(example_contraction example_contraction_scale) +add_example_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index f343cc1910..ac54aebdc2 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,2 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 3a8c2ef52f..7acb1a1907 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,40 +1,23 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) +add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) 
+add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d 
example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() # USE_BITINT_EXTENSION_INT4 - - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 93f16c945f..8b648a7f73 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) 
- add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index c6cca7b586..519f539106 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,11 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) - add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) - add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) - add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 
batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) +add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) +add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) +add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 5277b32f63..9a62d85ace 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,32 +1,23 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_splitK_gemm_xdl) +add_custom_target(example_splitK_gemm_xdl) +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) +add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) - 
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) +add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) - add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) +add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) - add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - endif() - - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 1ae179e950..72e6959649 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,27 +1,10 @@ 
-list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) +add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) - set(target 1) - endif() -endforeach() - -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - - add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) - - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 2d804cafe9..991c1e464b 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ 
b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,24 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - # Conv perlayer quantization - add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) - # Conv perchannel quantization - add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - # Conv + bias + relu perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - # Conv + bias + relu perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 
conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - # Conv + bias + tanh perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - # Conv + bias + tanh perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +# Conv perchannel quantization +add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index ae251e88d2..8ab56b21a6 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - 
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index a963399dc7..3cf4812509 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) +add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) +add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) 
add_example_executable(example_elementwise_permute elementwise_permute.cpp) if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942")) add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp new file mode 100644 index 0000000000..8819bb65e6 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 +using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise:: + BinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + BinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 
128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 2> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + 
std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 
0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp index 24e161c6d3..d3c3085eb8 100644 --- a/example/44_elementwise_permute/elementwise_permute.cpp +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -51,32 +39,7 @@ int main() std::vector ncdhw = {16, 8, 8, 8, 8}; std::vector ndhwc = {16, 8, 8, 8, 8}; - Tensor a(ncdhw); - Tensor b(ndhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - /**std::array a_strides = { - static_cast(ncdhw[1] * ncdhw[2] * 
ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[4]), - 1}; - std::array b_strides = { - static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), - 1, - static_cast(ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[4])};**/ std::array a_strides = { static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), @@ -93,6 +56,20 @@ int main() 1}; ck::ranges::copy(ncdhw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -126,10 +103,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp 
index f3aca57c35..47d8c4de65 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<4>, // InScalarPerVectorSeq ck::Sequence<4>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -59,10 +47,13 @@ int main() const int W = 5; const int D = 16; - std::vector ncdhw = {N, C, D, H, W}; - std::vector ndhwc = {N, D, H, W, C}; - Tensor a(ncdhw); - Tensor b(ndhwc); + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -74,10 +65,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, C, H, W, D}; - 
std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W - std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,11 +81,12 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * + ab_lengths[3] * ab_lengths[4]; std::size_t num_btype = - sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + - sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -111,10 +99,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 
1b28a901cb..3ea1aa4bf8 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +44,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -77,9 +54,22 @@ int main() 1, static_cast(nhwc[2] * nhwc[3]), static_cast(nhwc[3])}; - ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = 
{Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -111,10 +101,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 30231a3758..1747e6dd8b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include 
"ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - const std::vector& shape_nchw, - Functor functor) -{ - for(std::size_t n = 0; n < shape_nchw[0]; ++n) - for(std::size_t c = 0; c < shape_nchw[1]; ++c) - for(std::size_t h = 0; h < shape_nchw[2]; ++h) - for(std::size_t w = 0; w < shape_nchw[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -54,13 +40,16 @@ int main() const int N = 120; const int C = 128; const int H = 32; - const int W = 1024; + const int W = 32; - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; + std::array ab_lengths{N, H, W, C}; - Tensor a(nchw); - Tensor b(nhwc); + std::array a_strides = {C * H * W, W, 1, H * W}; + std::array b_strides = {H * W * C, W * C, C, 1}; + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -72,11 +61,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, H, W, C}; - - std::array a_strides = {C * H * W, W, 1, H * W}; - std::array b_strides = {H * W * C, W * C, C, 1}; - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,10 +78,11 @@ int main() float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t flop = + std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * 
ab_lengths[3]; - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -110,11 +95,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); - Tensor host_b(nhwc); - host_elementwise4D, Tensor, PassThrough>( - host_b, a, nchw, PassThrough{}); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index f832601f07..13c67fce05 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -6,9 +6,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -21,11 +23,14 @@ using F32 = float; using ADataType = F16; 
using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; @@ -60,8 +48,21 @@ int main() std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; - Tensor a(nchw); - Tensor b(nhwc); + std::array ab_lengths; + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float 
scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -84,22 +85,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -113,11 +106,10 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t num_btype = + (2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -129,10 +121,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, 
BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index bae85f53c1..0a0f6fec10 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = 
ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,29 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto 
argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +112,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index fe7acd3010..fc664186be 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = 
ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; bool time_kernel = true; - std::vector nchw = {5, 4, 2, 3}; - std::vector nhwc = {5, 2, 3, 4}; - Tensor a(nchw); - Tensor b(nhwc); + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; @@ -84,22 
+86,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -129,10 +123,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index aebdb37d9b..a0c416318a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -5,9 +5,11 @@ #include 
#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 
1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,28 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +111,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + 
ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp new file mode 100644 index 0000000000..050300eed2 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 +using TrinaryAddUnaryScaleSquare = + ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + TrinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 
0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 3> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor& a2 = as[2]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + float gamma = 4.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a2.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + a2_device_buf.ToDevice(a2.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + a2_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + 
TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "A2 (nchw): " << a2.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + ref_invoker.Run(ref_argument); + + const double threshold = std::pow(2, -10) * 2; + b_device_buf.FromDevice(b.mData.data()); + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold); + } + + return pass ? 
0 : 1; +} diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index 14432f6e23..df1956ca62 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp similarity index 100% rename from example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp rename to example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index 4dc6c8b4e0..63ee1d4312 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,15 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_im2col_col2im) +add_custom_target(example_im2col_col2im) - add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) - add_example_dependencies(example_im2col_col2im example_image_to_column_f32) +add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) +add_example_dependencies(example_im2col_col2im example_image_to_column_f32) - add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) - 
add_example_dependencies(example_im2col_col2im example_column_to_image_f32) - - set(target 1) - endif() -endforeach() +add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) +add_example_dependencies(example_im2col_col2im example_column_to_image_f32) diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index 57bc0b33ef..d3974897fe 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt index 42500b64e6..1b8bd4cad2 100644 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 6eaddd3ff7..5a35f9b608 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -2,16 +2,9 @@ add_subdirectory(binary) add_subdirectory(multi_AB) add_subdirectory(unary) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND 
target EQUAL 0) - add_custom_target(example_convnd_activ_xdl) - # ScaleAdd ScaleAdd Relu - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_xdl) +# ScaleAdd ScaleAdd Relu +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt index 89cc2d7f62..b3c77b3bd7 100644 --- a/example/64_fpAintB_gemm/CMakeLists.txt +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_fpAintB_gemm_wmma) - add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) - add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) -endif() +add_custom_target(example_fpAintB_gemm_wmma) +add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) +add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c19ba93b69..5465adb779 100644 --- a/example/CMakeLists.txt +++ 
b/example/CMakeLists.txt @@ -5,6 +5,12 @@ include_directories(BEFORE add_custom_target(examples) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(FILE_NAME) + add_dependencies(${EXAMPLE_NAME} ${FILE_NAME}) + endif() +endfunction() + function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) @@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example
${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index c93d1d0639..0bda8b7590 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -45,6 +45,10 @@ #endif // define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif @@ -62,8 +66,7 @@ // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -75,8 +78,7 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ - defined(__gfx94__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -89,7 +91,7 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // 
for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_MFMA #endif @@ -120,7 +122,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index d197c56ab8..c98ec6e2aa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -23,6 +23,7 @@ namespace device { template (gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } id_off += 
grid_size_grp; id_local += grid_size_grp; @@ -193,8 +224,11 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeType = ADataType, + typename ALDSType = ComputeType, + typename BLDSType = ComputeType> struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK{}; static constexpr auto I2 = Number<2>{}; + using AComputeType = ComputeType; + using BComputeType = ComputeType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, - ComputeType, + AComputeType, + BComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; + LoopSched, + PipelineVer, + ALDSType, + BLDSType>; template struct OffsettedBlockToCTileMapMLoops @@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK, - GemmSpec, - ALayout, - BLayout, - DsLayout, - ELayout, - DsDataType, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - e_global_memory_operation_, - has_main_k_block_loop_>; + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - reinterpret_cast(arg.p_workspace_), - arg.barrier_size_grp_, - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + return 
launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } }; constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; constexpr auto Set = InMemoryDataOperationEnum::Set; - // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced - // in IsSupportedArgument function + // For bf16 datatype only kbatch = 1 scenario is supported. 
This condition is + // enforced in IsSupportedArgument function if constexpr(std::is_same::value) { if(has_main_k_block_loop) @@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; + +struct Multiply +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, 
const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + struct ScaleAdd { __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100644 index 0000000000..6d1d6b57c5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) 
{} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +} // namespace element_wise +} // namespace 
tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 9c64ad4dfa..1add81e69e 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -12,10 +12,6 @@ namespace ck { namespace tensor_operation { namespace element_wise { -#if CK_WORKAROUND_SWDEV_383542 -extern "C" __device__ float __ocml_native_recip_f32(float); -#endif - struct PassThroughPack2 { template @@ -449,11 +445,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __expf(u); -#if !CK_WORKAROUND_SWDEV_383542 - y = x * __frcp_rn(1.f + emu); -#else - y = x * __ocml_native_recip_f32(1.f + emu); -#endif + y = x * ck::math::rcp(1.f + emu); } template <> @@ -559,6 +551,244 @@ struct TanH }; }; +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + 
is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + 
is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cosh(x); + }; +}; + +struct Floor +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::rcp(x); + }; +}; + struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp index 2a906a1432..4d1a09b445 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -118,8 +118,16 @@ struct GridwiseElementwise __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); const index_t m1_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); - const auto thread_grid_offset = - make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); using ThisThreadBlock = ThisThreadBlock; // If src and dst have same vector dim, then: @@ -157,9 +165,9 @@ struct GridwiseElementwise uniform_sequence_gen_t, uniform_sequence_gen_t, uniform_sequence_gen_t>{in_grid_desc_tuple, - thread_grid_offset, + input_thread_grid_offset, out_grid_desc_tuple, - thread_grid_offset, + output_thread_grid_offset, elementwise_op}; global_to_global_transfer.Run( in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 4f37049b20..ae93af192e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -31,7 +31,8 @@ namespace ck { // D0, D1, ... 
and E have the same layout template + PipelineVersion PipelineVer, + typename ALDSType, + typename BLDSType> struct GridwiseGemmMultipleD_xdl_splitk_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeType), + return math::max(a_block_space_size_aligned * sizeof(ALDSType) + + b_block_space_size_aligned * sizeof(BLDSType), c_block_size * sizeof(CShuffleDataType)); } @@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle InMemoryDataOperationEnum EGlobalMemoryDataOperation, index_t NumDTensor_, typename DsDataType_, + bool Zeroing, typename AGridDesc_KBatch_AK0_M_AK1, typename BGridDesc_KBatch_BK0_N_BK1, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, @@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - ComputeType, + ALDSType, decltype(a_grid_desc_kbatch_ak0_m_ak1), decltype(a_block_desc_kbatch_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BDataType, - ComputeType, + BLDSType, decltype(b_grid_desc_kbatch_bk0_n_bk1), decltype(b_block_desc_kbatch_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, - ComputeType, + ALDSType, + BLDSType, AccDataType, 
decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeType, + BComputeType>(); -#if 1 - if(block_work_idx[I0] == 0) + if constexpr(Zeroing) { - const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; - const index_t numNThreads = NPerBlock / nThreadSize; - const index_t numMThreads = BlockSize / numNThreads; - const index_t mThreadSize = MPerBlock / numMThreads; - - const index_t m_tid = get_thread_local_1d_id() / numNThreads; - const index_t n_tid = get_thread_local_1d_id() % numNThreads; - - auto c_thread_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, Number{})); - - StaticBuffer - e_thread_zero_buf; - - auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< - EDataType, - EDataType, - decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), - decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), - ck::tensor_operation::element_wise::PassThrough, - Sequence<1, mThreadSize, 1, nThreadSize>, - Sequence<0, 1, 2, 3>, - 3, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - InMemoryDataOperationEnum::Set, - 1, - true>{e_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], - m_tid * mThreadSize, - block_work_idx[I2], - n_tid * nThreadSize), - ck::tensor_operation::element_wise::PassThrough{}}; - - c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, - make_tuple(I0, I0, I0, I0), - e_thread_zero_buf, - e_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_buf); - - __syncthreads(); - - if(threadIdx.x == 0) + if(block_work_idx[I0] == 0) { - atomicAdd(barrier_count_finished, 1); + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const 
index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __builtin_amdgcn_s_barrier(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } } } -#endif auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); @@ -711,13 +718,15 @@ struct 
GridwiseGemmMultipleD_xdl_splitk_cshuffle // shuffle C and write out { - if(threadIdx.x == 0) + if constexpr(Zeroing) { - while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + __builtin_amdgcn_s_barrier(); } - __syncthreads(); - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle } }); - if(threadIdx.x == 0) + if constexpr(Zeroing) { - index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - - if(k_id_finished_t == KBatch) + if(threadIdx.x == 0) { - *barrier_count_finished = 0; + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } } } } } + template + __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = 
MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run, true>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + template ( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - barrier_count_finished, - KBatch, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_kbatch_ak0_m_ak1, - 
b_grid_desc_kbatch_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); - } - else - { - Run>( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - barrier_count_finished, - KBatch, - a_element_op, - b_element_op, - ck::tensor_operation::element_wise::PassThrough{}, - a_grid_desc_kbatch_ak0_m_ak1, - b_grid_desc_kbatch_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); - } + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + nullptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } }; diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index a07fde3da3..2b921cdc7c 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -14,6 +14,10 @@ namespace ck { namespace math { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + // math functions for the host, some are implemented by calling C++ std functions static inline __host__ float abs(float x) { return std::abs(x); }; @@ -111,6 +115,276 @@ inline __host__ double tanh(double x) return std::tanh(x); }; +template +inline __host__ T acos(T x) +{ + return ck::type_convert(std::acosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acos(float x) +{ + return std::acosf(x); +}; + +template <> +inline __host__ double acos(double x) +{ + return std::acos(x); +}; + +template +inline __host__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __host__ float neg(float x) +{ + return -x; +}; + +template <> +inline __host__ double neg(double x) +{ + return -x; +}; + +template <> +inline __host__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __host__ int8_t neg(int8_t x) +{ + return -x; +}; + +template +inline __host__ T atan(T x) +{ + return ck::type_convert(std::atanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atan(float x) +{ + return std::atanf(x); +}; + +template <> +inline __host__ double atan(double x) +{ + return std::atan(x); +}; + +template +inline __host__ T sin(T x) +{ + return ck::type_convert(std::sinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sin(float x) +{ + return std::sinf(x); +}; + +template <> +inline __host__ double sin(double x) +{ + return std::sin(x); +}; + +template +inline __host__ T asin(T x) +{ + return ck::type_convert(std::asinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float asin(float x) +{ + return std::asinf(x); +}; + +template <> +inline __host__ double asin(double x) +{ + return std::asin(x); +}; + +template +inline __host__ T asinh(T x) +{ + return ck::type_convert(std::asinhf(ck::type_convert(x))); +}; + 
+template <> +inline __host__ float asinh(float x) +{ + return std::asinhf(x); +}; + +template <> +inline __host__ double asinh(double x) +{ + return std::asinh(x); +}; + +template +inline __host__ T cos(T x) +{ + return ck::type_convert(std::cosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cos(float x) +{ + return std::cosf(x); +}; + +template <> +inline __host__ double cos(double x) +{ + return std::cos(x); +}; + +template +inline __host__ T acosh(T x) +{ + return ck::type_convert(std::acoshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acosh(float x) +{ + return std::acoshf(x); +}; + +template <> +inline __host__ double acosh(double x) +{ + return std::acosh(x); +}; + +template +inline __host__ T tan(T x) +{ + return ck::type_convert(std::tanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float tan(float x) +{ + return std::tanf(x); +}; + +template <> +inline __host__ double tan(double x) +{ + return std::tan(x); +}; + +template +inline __host__ T atanh(T x) +{ + return ck::type_convert(std::atanhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atanh(float x) +{ + return std::atanhf(x); +}; + +template <> +inline __host__ double atanh(double x) +{ + return std::atanh(x); +}; + +template +inline __host__ T sinh(T x) +{ + return ck::type_convert(std::sinhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sinh(float x) +{ + return std::sinhf(x); +}; + +template <> +inline __host__ double sinh(double x) +{ + return std::sinh(x); +}; + +template +inline __host__ T ceil(T x) +{ + return ck::type_convert(std::ceilf(ck::type_convert(x))); +}; + +template <> +inline __host__ float ceil(float x) +{ + return std::ceilf(x); +}; + +template <> +inline __host__ double ceil(double x) +{ + return std::ceil(x); +}; + +template +inline __host__ T cosh(T x) +{ + return ck::type_convert(std::coshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cosh(float x) +{ + return 
std::coshf(x); +}; + +template <> +inline __host__ double cosh(double x) +{ + return std::cosh(x); +}; + +template +inline __host__ T floor(T x) +{ + return ck::type_convert(std::floorf(ck::type_convert(x))); +}; + +template <> +inline __host__ float floor(float x) +{ + return std::floorf(x); +}; + +template <> +inline __host__ double floor(double x) +{ + return std::floor(x); +}; + +template +inline __host__ T rcp(T x) +{ + return ck::type_convert(1.f / ck::type_convert(x)); +}; + template inline __host__ T exp(T x) { @@ -282,6 +556,286 @@ inline __device__ double tanh(double x) return ::tanh(x); }; +template +inline __device__ T acos(T x) +{ + return ck::type_convert(::acosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acos(float x) +{ + return ::acosf(x); +}; + +template <> +inline __device__ double acos(double x) +{ + return ::acos(x); +}; + +template +inline __device__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __device__ float neg(float x) +{ + return -x; +}; + +template <> +inline __device__ double neg(double x) +{ + return -x; +}; + +template <> +inline __device__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __device__ int8_t neg(int8_t x) +{ + return -x; +}; + +template <> +inline __device__ half_t neg(half_t x) +{ + return __hneg(x); +}; + +template +inline __device__ T atan(T x) +{ + return ck::type_convert(::atanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atan(float x) +{ + return ::atanf(x); +}; + +template <> +inline __device__ double atan(double x) +{ + return ::atan(x); +}; + +template +inline __device__ T sin(T x) +{ + return ck::type_convert(::sinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sin(float x) +{ + return ::sinf(x); +}; + +template <> +inline __device__ double sin(double x) +{ + return ::sin(x); +}; + +template <> +inline __device__ half_t sin(half_t x) +{ + return ::hsin(x); +}; + +template +inline 
__device__ T asin(T x) +{ + return ck::type_convert(::asinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asin(float x) +{ + return ::asinf(x); +}; + +template <> +inline __device__ double asin(double x) +{ + return ::asin(x); +}; + +template +inline __device__ T asinh(T x) +{ + return ck::type_convert(::asinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asinh(float x) +{ + return ::asinhf(x); +}; + +template <> +inline __device__ double asinh(double x) +{ + return ::asinh(x); +}; + +template +inline __device__ T acosh(T x) +{ + return ck::type_convert(::acoshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acosh(float x) +{ + return ::acoshf(x); +}; + +template <> +inline __device__ double acosh(double x) +{ + return ::acosh(x); +}; + +template +inline __device__ T tan(T x) +{ + return ck::type_convert(::tanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float tan(float x) +{ + return ::tanf(x); +}; + +template <> +inline __device__ double tan(double x) +{ + return ::tan(x); +}; + +template +inline __device__ T atanh(T x) +{ + return ck::type_convert(::atanhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atanh(float x) +{ + return ::atanhf(x); +}; + +template <> +inline __device__ double atanh(double x) +{ + return ::atanh(x); +}; + +template +inline __device__ T sinh(T x) +{ + return ck::type_convert(::sinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sinh(float x) +{ + return ::sinhf(x); +}; + +template <> +inline __device__ double sinh(double x) +{ + return ::sinh(x); +}; + +template +inline __device__ T ceil(T x) +{ + return ck::type_convert(::ceilf(ck::type_convert(x))); +}; + +template <> +inline __device__ float ceil(float x) +{ + return ::ceilf(x); +}; + +template <> +inline __device__ double ceil(double x) +{ + return ::ceil(x); +}; + +template <> +inline __device__ half_t ceil(half_t x) +{ + return ::hceil(x); +}; + +template +inline 
__device__ T cosh(T x) +{ + return ck::type_convert(::coshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cosh(float x) +{ + return ::coshf(x); +}; + +template <> +inline __device__ double cosh(double x) +{ + return ::cosh(x); +}; + +template +inline __device__ T floor(T x) +{ + return ck::type_convert(::floorf(ck::type_convert(x))); +}; + +template <> +inline __device__ float floor(float x) +{ + return ::floorf(x); +}; + +template <> +inline __device__ double floor(double x) +{ + return ::floor(x); +}; + +template <> +inline __device__ half_t floor(half_t x) +{ + return ::hfloor(x); +}; + +template +inline __device__ T rcp(T x) +{ +#if !CK_WORKAROUND_SWDEV_383542 + return __frcp_rn(x); +#else + return __ocml_native_recip_f32(x); +#endif +}; + template inline __device__ T exp(T x) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp new file mode 100644 index 0000000000..470641fff7 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceElementwise : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op} + { + } + + const std::array, NumATensors>& a_tensors_; + Tensor& b_tensor_; + ElementOp element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceElementwise::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumATensors == 1) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx)); + }); + } + else if constexpr(NumATensors == 2) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx)); + }); + } + else if constexpr(NumATensors == 3) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), + arg.a_tensors_[0](idx), + arg.a_tensors_[1](idx), + arg.a_tensors_[2](idx)); + }); + } + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + { + return Argument{a_tensors, b_tensor, element_op}; + } + + static auto 
MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceElementwise" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index ee9d977096..50c18fc22e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -12,398 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "gemm_dl.inc" +#endif +#ifdef CK_USE_WMMA +#include "gemm_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS) -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - 
-void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS) -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS) -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - 
instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - 
instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP64 -void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP8 -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( - std::vector>>& instances); - -void 
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif - -void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } +#ifdef CK_ENABLE_FP16 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + 
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + } + } #endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + 
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); + } + } +#endif +#endif // DL_KERNELS + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif +#endif + +#ifdef CK_USE_XDL + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( op_ptrs); @@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( op_ptrs); @@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( op_ptrs); @@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory< else if constexpr(is_same_v && 
is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( op_ptrs); @@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); - 
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); -#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif @@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); -#endif } } #endif 
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs); } } +#endif #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 1a518a5302..6ee88bd855 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 +#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA) void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector> op_ptrs; -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc new file mode 100644 index 0000000000..44a11f6284 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if defined(CK_ENABLE_FP16) +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif +#if 
defined(CK_ENABLE_FP32) +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_INT8) +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc new file mode 100644 index 0000000000..c97298c258 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc new file mode 100644 index 0000000000..82a1dc425a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP64 +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP8 +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 09885ccd90..9a70a47274 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -10,439 +10,18 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_data_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_data_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -// conv2d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); 
- -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); - -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -// conv3d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef 
CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - std::vector>>& instances); -#endif template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 2) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); - 
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } } - else if constexpr(NumDimSpatial == 3) + if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef 
CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef 
CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( op_ptrs); @@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); 
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( op_ptrs); } #endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - op_ptrs); - } -#endif -#ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( op_ptrs); @@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory< #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc new file mode 100644 index 0000000000..fb2407bcc3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + 
std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc new file mode 100644 index 0000000000..7ad0218410 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); +#endif + +// conv3d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void 
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index b8ca2c5fac..dc56b8f4b2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -12,565 +12,20 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_backward_weight_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_weight_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_weight_wmma.inc" +#endif namespace ck { namespace tensor_operation { namespace device { namespace instance { -// xdl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 
-void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef DL_KERNELS -// dl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 
-void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& 
instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#endif - template > op_ptrs; +#ifdef DL_KERNELS if constexpr(NumDimSpatial == 1) { if constexpr(is_same_v && is_same_v && @@ -621,10 +77,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -632,10 +85,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -644,19 +94,14 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - op_ptrs); } #endif } if constexpr(is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -683,6 +128,174 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + if 
constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif // DL_KERNELS +#ifdef CK_USE_XDL + if constexpr(NumDimSpatial == 1) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } #endif } } @@ -696,10 +309,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); } @@ -709,10 +318,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - op_ptrs); -#endif 
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( op_ptrs); } @@ -723,10 +328,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); } @@ -740,10 +341,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); } @@ -753,10 +350,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); } @@ -767,10 +360,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( op_ptrs); } @@ -787,10 +376,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); } @@ -800,16 +385,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); - 
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -818,13 +395,70 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - op_ptrs); -#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } +#endif + } + } +#endif +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && @@ -842,50 +476,17 @@ struct DeviceOperationInstanceFactory 
&& is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( op_ptrs); } #endif -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - op_ptrs); - } -#endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v && @@ -897,18 +498,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - op_ptrs); - } #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc new file mode 100644 index 0000000000..59190a13e5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void 
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc new file mode 100644 index 0000000000..315547ca56 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv3d backward weight +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc new file mode 100644 index 0000000000..5562d236e6 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void 
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index b9712542a8..24a5f9a5cb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -12,908 +12,21 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef DL_KERNELS +#include "grouped_convolution_forward_dl.inc" +#endif +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_BF16 -// grouped conv1d forward, GNWC/GKXC/GNWK -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif - 
-#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -// grouped conv2d forward, NHWGC/GKYXC/NHWGK -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void 
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void 
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF8 -void 
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); - -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - template > op_ptrs; +#ifdef DL_KERNELS + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } +#endif + } + if 
constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } +#endif + } +#endif // DL_KERNELS + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && is_same_v) { @@ -1000,35 +154,13 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); } #endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1037,23 +169,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); - } -#endif } 
if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { - #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1061,15 +181,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1077,15 +188,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1093,16 +195,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); - } #endif } @@ -1121,12 +213,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1142,11 +228,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); - 
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -1188,12 +269,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1209,6 +284,99 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); + } +#endif + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if 
constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + 
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( @@ -1217,6 +385,7 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc new file mode 100644 index 0000000000..0ea24d0929 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( 
+ std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc new file mode 100644 index 0000000000..942674ef99 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +// grouped conv1d forward, GNWC/GKXC/GNWK +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void 
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0a12e1c49e..c035e7e564 100644 --- 
a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build XDL instances if gfx9 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_library(${INSTANCE_NAME} OBJECT ${ARGN}) @@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. 
Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") + set(add_inst 0) + endif() if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 69b6ddc754..1227a77a38 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(BATCHED_GEMM_INSTANCES) list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index d0e9b265af..5c8470f7cb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index cd9c95c066..8082a8c275 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 865a31e79a..2aa607429d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 28226fabac..51bbdf1d7c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 6244477e16..e43eb07fb6 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3fd4e03449..f1fb0646e4 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4e..a28c6717dd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3f..b91de832e4 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ 
-1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt index 75a3670761..796a9b2402 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv1d_bwd_data_instance device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 49dfc01fd9..2da5155117 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_BWD_DATA_INSTANCES) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index ba0ca32517..04b313d075 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 670cd94fc9..4304d8996c 100644 --- 
a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index 68d5f582fd..40a6b1ff09 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_add_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt index db92208fd7..ec4a8a2864 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv3d_bwd_data_instance device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index fe85bb7ead..298da1fbef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp 
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index bbf81a5fa2..04ae90bc5b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 63b4a00c99..45d6abce01 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index eb9345cbad..d859078ca9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_multiply_instance device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 969361de9a..043bdab001 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 97693a2566..b9aeb6a6db 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index c10d4773a7..e6ca64cdc1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_silu_instance device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt 
index ccada3a85e..f29943d93b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 426edeed74..61892e708c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 17d27ab150..2f45401ec6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 6cbd7528e7..aba9806a74 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_MULTIPLY_ADD_INSTANCES) list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 2b2cf8c774..7ee3efe7f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 059b6a720f..dac86d7707 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_SPLITK_INSTANCES) list(APPEND GEMM_SPLITK_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index 8dd0112a6b..c854b16eeb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index cfd829f87e..ab4313d89e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index f51a484bb5..ca4ea515bb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 93d5bd7422..ad430340ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library( device_grouped_conv2d_bwd_data_instance xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 8a896b06c7..340ddfb3f0 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 2715a8cf21..1d3c3747d3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_instance #xdl # GNHWC, GKYXC, GNHWK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 836e671bf2..29fa8fa3c5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index e1cb975291..ae6dcb9880 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR 
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index b7901a2815..fa48f0edcc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 968e8dea2f..8b89dcf7ec 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 3825b92af4..972fb54031 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp 
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 49706588d6..436c37fd58 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 45d270d554..f36d55d367 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 08fb23afc9..1076249447 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index ae89caaeef..1be1db7d1d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index de7537af47..2625e6cbe8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index ef8a440c1a..167dfa9a6f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 648f2146cb..8e9693e691 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index ac22543bef..1ee3d0add4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index c22a6e9e96..5d50902be8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index c69e36142d..186a24501e 100644 --- 
a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -14,6 +14,8 @@ #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -21,14 +23,6 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template -void reference_permute_scale(HostTensorB& b_tensor, - const HostTensorA& a_tensor, - ElementOp tensor_op) -{ - b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); -} - namespace profiler { template @@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; - Tensor a(lengths_vector, input_strides_vector); + std::array, 1> as = {Tensor(lengths_vector, input_strides_vector)}; + Tensor& a = as[0]; Tensor b(lengths_vector, output_strides_vector); Tensor host_b(lengths_vector, output_strides_vector); @@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{scale}); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, ElementOp>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, ElementOp{scale}); + + ref_invoker.Run(ref_argument); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 11ae285167..cb6ffbec6c 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -2,19 +2,6 @@ set(PROFILER_SOURCES profiler.cpp profile_gemm.cpp - 
profile_gemm_splitk.cpp - profile_gemm_bias_add_reduce.cpp - profile_gemm_add_multiply.cpp - profile_gemm_multiply_add.cpp - profile_gemm_reduce.cpp - profile_batched_gemm.cpp - profile_batched_gemm_reduce.cpp - profile_conv_fwd.cpp - profile_conv_fwd_bias_relu.cpp - profile_conv_fwd_bias_relu_add.cpp - profile_conv_bwd_data.cpp - profile_grouped_conv_fwd.cpp - profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp @@ -29,36 +16,57 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp ) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + list(APPEND PROFILER_SOURCES 
profile_batched_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) + list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + +endif() + +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +endif() + if(DL_KERNELS) list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) -endif() - -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED 
DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -68,25 +76,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE 
device_grouped_conv3d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) @@ -96,39 +85,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) + 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() - +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +endif() if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0f90256c0..720ab468ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) @@ -100,6 +111,18 @@ function(add_gtest_executable 
TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 9482821b68..759cf3da67 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) +add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/batched_gemm/test_batched_gemm.cpp b/test/batched_gemm/test_batched_gemm_xdl.cpp similarity index 100% rename from test/batched_gemm/test_batched_gemm.cpp rename to test/batched_gemm/test_batched_gemm_xdl.cpp diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 03f1d3a4eb..2b3288ef9d 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_gemm) - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) - 
if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_gemm) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) +endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp rename to test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 32c6ee85d1..c5868e4d7a 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,11 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - set(target 1) - endif() +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) endif() -endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp rename to 
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c011a6a3c5..c042d7e000 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm) - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) +endif() diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp rename to test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3164863eef..2e09073540 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,29 +1,21 @@ -list(APPEND gpu_list gfx908 
gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm_permute) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - endif() - - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +add_custom_target(test_batched_gemm_softmax_gemm_permute) +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 
test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) +endif() diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp diff --git 
a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index a86e72fddb..3ba0d82f0e 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,13 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) - add_gtest_executable(test_contraction test_contraction.cpp) - target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) - 
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - set(target 1) - endif() +if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) + add_gtest_executable(test_contraction test_contraction_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) endif() -endforeach() + add_gtest_executable(test_contraction_interface test_contraction_interface_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + endif() +endif() diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface_xdl.cpp similarity index 100% rename from test/contraction/test_contraction_interface.cpp rename to test/contraction/test_contraction_interface_xdl.cpp diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction_xdl.cpp similarity index 100% rename from test/contraction/test_contraction.cpp rename to test/contraction/test_contraction_xdl.cpp diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index f734b46f53..e68a9b243c 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +add_gtest_executable(test_convnd_bwd_data convnd_bwd_data_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git 
a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp similarity index 100% rename from test/convnd_bwd_data/convnd_bwd_data.cpp rename to test/convnd_bwd_data/convnd_bwd_data_xdl.cpp diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 745aceffc9..ba6d16a0d5 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) +add_gtest_executable(test_convnd_fwd convnd_fwd_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) - set(target 1) - endif() -endforeach() +endif() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp similarity index 100% rename from test/convnd_fwd/convnd_fwd.cpp rename to test/convnd_fwd/convnd_fwd_xdl.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 7df3f90abc..ab4c781847 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,11 +1,19 @@ -add_gtest_executable(test_gemm_add test_gemm_add.hpp) -target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) -target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) 
-target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) -target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_fastgelu.cpp rename to test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index c1c55140a0..1b12ab7528 100644 --- a/test/gemm_add/test_gemm_add_fastgelu.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddFastgelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_relu.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_relu.cpp rename to test/gemm_add/test_gemm_add_relu_xdl.cpp index ba6aab36bd..e8b769b1cb 100644 --- a/test/gemm_add/test_gemm_add_relu.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddRelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_silu.cpp 
b/test/gemm_add/test_gemm_add_silu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_silu.cpp rename to test/gemm_add/test_gemm_add_silu_xdl.cpp index d4dd6fa38b..75fa59a8e7 100644 --- a/test/gemm_add/test_gemm_add_silu.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddSilu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add.hpp b/test/gemm_add/test_gemm_add_xdl.hpp similarity index 100% rename from test/gemm_add/test_gemm_add.hpp rename to test/gemm_add/test_gemm_add_xdl.hpp diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index bfc4404bd8..d1102a561a 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_gemm_layernorm) - add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) - set(target 1) - endif() - endif() -endforeach() +add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_gemm_layernorm) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) +endif() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp 
b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp similarity index 100% rename from test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp rename to test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 42a53c3048..121ecde609 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) endif() \ No newline at end of file diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/gemm_reduce/gemm_reduce_fp16.cpp rename to test/gemm_reduce/gemm_reduce_fp16_xdl.cpp diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index caf30fca59..4b66dddef9 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) +add_gtest_executable(test_gemm_splitk test_gemm_splitk_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) - set(target 1) endif() -endforeach() diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk_xdl.cpp similarity index 100% rename from test/gemm_split_k/test_gemm_splitk.cpp rename to test/gemm_split_k/test_gemm_splitk_xdl.cpp diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 305c568ee9..3507989bae 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ 
b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,19 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp)
+if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index d7d6f8a3d6..34cdc63cd9 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,20 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight_xdl_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance
device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1ce878d5ca..4f245d63cd 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,8 +1,14 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) -target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +endif() -add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility 
test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 8c57b667e2..f47685cf91 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,14 +1,13 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_grouped_gemm) - add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) - add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) - - add_dependencies(test_grouped_gemm 
test_grouped_gemm_splitk test_grouped_gemm_interface) - set(target 1) - endif() -endforeach() +add_custom_target(test_grouped_gemm) + +add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +endif() + +add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_interface) +endif() diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_interface.cpp rename to test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_splitk.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt index 1b6decfed7..65f33da74d 100644 --- a/test/normalization_bwd_data/CMakeLists.txt +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_data) add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) 
add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt index f3579aad08..afb78dc58e 100644 --- a/test/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_gamma_beta) add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt 
index be6aaf94aa..d63cb79910 100644 --- a/test/permute_scale/CMakeLists.txt +++ b/test/permute_scale/CMakeLists.txt @@ -1,6 +1,4 @@ add_custom_target(test_permute) add_gtest_executable(test_permute_scale test_permute_scale.cpp) -if(result EQUAL 0) - target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) - add_dependencies(test_permute test_permute_scale) -endif() +target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) +add_dependencies(test_permute test_permute_scale) diff --git a/test/transpose/CMakeLists.txt b/test/transpose/CMakeLists.txt index 530cc9d72d..fb9379bea9 100644 --- a/test/transpose/CMakeLists.txt +++ b/test/transpose/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_transpose test_transpose.cpp) - target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) - set(target 1) - endif() -endforeach() +add_gtest_executable(test_transpose test_transpose_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) +endif() diff --git a/test/transpose/test_transpose.cpp b/test/transpose/test_transpose_xdl.cpp similarity index 100% rename from test/transpose/test_transpose.cpp rename to test/transpose/test_transpose_xdl.cpp diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 383707828c..1eb6c35db2 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -12,10 +12,8 @@ add_dependencies(test_wrapper test_wrapper_copy) add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) target_link_libraries(test_wrapper_partition PRIVATE utility) add_dependencies(test_wrapper test_wrapper_partition) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES 
"gfx941" OR - GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) +add_gtest_executable(test_wrapper_gemm test_wrapper_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_wrapper_gemm PRIVATE utility) add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm_xdl.cpp similarity index 100% rename from test/wrapper/test_wrapper_gemm.cpp rename to test/wrapper/test_wrapper_gemm_xdl.cpp