mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Merge branch 'develop' into ck_tile/refactor
This commit is contained in:
@@ -145,6 +145,22 @@ if(GPU_TARGETS)
|
||||
else()
|
||||
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
|
||||
endif()
|
||||
|
||||
if (GPU_TARGETS)
|
||||
if (GPU_TARGETS MATCHES "gfx9")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx11")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
|
||||
find_package(hip)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
|
||||
@@ -1,27 +1,29 @@
|
||||
add_custom_target(client_gemm_fastgelu_examples)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_custom_target(client_gemm_fastgelu_examples)
|
||||
|
||||
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
|
||||
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
|
||||
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
|
||||
client_gemm_fastgelu)
|
||||
|
||||
add_custom_target(client_gemm_fastgelu_generic_examples)
|
||||
add_custom_target(client_gemm_fastgelu_generic_examples)
|
||||
|
||||
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
|
||||
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
|
||||
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
|
||||
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
|
||||
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
|
||||
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
|
||||
|
||||
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
|
||||
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
|
||||
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
|
||||
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
|
||||
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
|
||||
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
|
||||
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
|
||||
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
|
||||
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
|
||||
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
|
||||
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
|
||||
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
|
||||
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
|
||||
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
@@ -1,5 +1,7 @@
|
||||
add_executable(client_fused_attention fused_attention.cpp)
|
||||
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_fused_attention fused_attention.cpp)
|
||||
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
|
||||
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
|
||||
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_quantization gemm_quantization.cpp)
|
||||
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
add_executable(client_gemm_quantization gemm_quantization.cpp)
|
||||
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
|
||||
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
|
||||
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
|
||||
|
||||
target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
|
||||
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
|
||||
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES))
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
# Fwd scaleadd scaleadd relu
|
||||
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
|
||||
grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
|
||||
@@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke
|
||||
add_executable(client_grouped_convnd_bwd_data_scale_fp16
|
||||
grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp)
|
||||
target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
@@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
|
||||
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
|
||||
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
|
||||
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
|
||||
GPU_TARGETS MATCHES "gfx942")
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
|
||||
target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
|
||||
|
||||
@@ -48,7 +48,10 @@ else()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
|
||||
find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
find_package(composable_kernel COMPONENTS device_contraction_operations)
|
||||
endif()
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
|
||||
@@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
|
||||
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_gemm_wmma)
|
||||
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
endif()
|
||||
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
|
||||
@@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
|
||||
add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
@@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
|
||||
|
||||
add_custom_target(example_gemm_wmma)
|
||||
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
|
||||
@@ -1,20 +1,3 @@
|
||||
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
|
||||
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
|
||||
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
|
||||
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
|
||||
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
|
||||
|
||||
@@ -1,8 +1 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
|
||||
|
||||
@@ -1,29 +1,20 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_gemm_add_add_fastgelu_xdl)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
|
||||
add_custom_target(example_gemm_add_add_fastgelu_xdl)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
|
||||
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
|
||||
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
|
||||
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(gpu_list "")
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_fwd_reduce_xdl)
|
||||
add_custom_target(example_convnd_fwd_reduce_xdl)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
|
||||
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
|
||||
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
|
||||
|
||||
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
|
||||
add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
# dlops
|
||||
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
|
||||
# xdlops
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
|
||||
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
|
||||
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
|
||||
|
||||
@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
|
||||
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
|
||||
@@ -36,7 +36,7 @@ using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F32;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
problem_size.Ns.push_back(256);
|
||||
problem_size.Ks.push_back(128);
|
||||
problem_size.Ms.push_back(128 + rand() % 128);
|
||||
problem_size.Ns.push_back(1024);
|
||||
problem_size.Ks.push_back(1024);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
|
||||
@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ADataType = F16;
|
||||
using BDataType = F8;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -1,48 +1,41 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_gemm_reduce_xdl)
|
||||
add_custom_target(example_gemm_reduce_xdl_max)
|
||||
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
|
||||
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
|
||||
add_custom_target(example_gemm_reduce_xdl)
|
||||
add_custom_target(example_gemm_reduce_xdl_max)
|
||||
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
|
||||
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
|
||||
|
||||
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
|
||||
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
|
||||
|
||||
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
|
||||
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
|
||||
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
|
||||
|
||||
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
|
||||
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
|
||||
|
||||
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
|
||||
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
|
||||
|
||||
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
|
||||
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
|
||||
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
|
||||
|
||||
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
|
||||
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
|
||||
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
|
||||
|
||||
add_example_dependencies(example_gemm_reduce_xdl
|
||||
example_gemm_reduce_xdl_mean_meansquare
|
||||
example_gemm_reduce_xdl_max
|
||||
example_gemm_add_add_mean_meansquare_xdl)
|
||||
add_example_dependencies(example_gemm_reduce_xdl
|
||||
example_gemm_reduce_xdl_mean_meansquare
|
||||
example_gemm_reduce_xdl_max
|
||||
example_gemm_add_add_mean_meansquare_xdl)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
|
||||
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
|
||||
endif()
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
|
||||
@@ -1,29 +1,15 @@
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
set(target 1)
|
||||
endif()
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_custom_target(example_grouped_conv_bwd_weight_dl)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
|
||||
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
|
||||
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
|
||||
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
|
||||
add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
|
||||
|
||||
@@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear)
|
||||
|
||||
# FP32
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16)
|
||||
|
||||
# FP64
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32)
|
||||
|
||||
# FP16
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32)
|
||||
|
||||
# BF16
|
||||
add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
|
||||
add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32)
|
||||
|
||||
add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp)
|
||||
add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
|
||||
add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32)
|
||||
|
||||
add_dependencies(example_contraction example_contraction_scale)
|
||||
add_dependencies(example_contraction example_contraction_bilinear)
|
||||
add_example_dependencies(example_contraction example_contraction_scale)
|
||||
add_example_dependencies(example_contraction example_contraction_bilinear)
|
||||
|
||||
@@ -1,5 +1,2 @@
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
|
||||
endif()
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
|
||||
|
||||
@@ -1,40 +1,23 @@
|
||||
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
|
||||
add_custom_target(example_grouped_conv_fwd_multiple_d)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_fwd_multiple_d)
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
|
||||
endif() # USE_BITINT_EXTENSION_INT4
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
|
||||
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
|
||||
endif() # USE_BITINT_EXTENSION_INT4
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
|
||||
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
|
||||
endif()
|
||||
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
|
||||
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
|
||||
|
||||
add_custom_target(example_gemm_scale_softmax_gemm)
|
||||
|
||||
|
||||
@@ -1,32 +1,23 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_splitK_gemm_xdl)
|
||||
add_custom_target(example_splitK_gemm_xdl)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
|
||||
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
|
||||
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
|
||||
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
|
||||
|
||||
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
|
||||
endif()
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
|
||||
add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
|
||||
endif()
|
||||
|
||||
@@ -1,27 +1,10 @@
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_data)
|
||||
add_custom_target(example_grouped_conv_bwd_data)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
|
||||
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_data)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
|
||||
|
||||
@@ -1,24 +1,17 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
|
||||
# Conv perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
|
||||
# Conv perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + relu perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + relu perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
# Conv perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
|
||||
# Conv perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + relu perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + relu perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list2 gfx908 gfx90a)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
|
||||
@@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp)
|
||||
add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
|
||||
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_permute elementwise_permute.cpp)
|
||||
if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
|
||||
add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
|
||||
|
||||
140
example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp
Normal file
140
example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using BinaryAdd = ck::tensor_operation::element_wise::Add;
|
||||
// B = alpha * A0 * A0 + beta * A1 * A1
|
||||
using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise::
|
||||
BinaryWithUnaryCombinedOp<BinaryAdd, UnaryScaleSquare, UnaryScaleSquare>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType, ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
BinaryAddUnaryScaleSquare, // ElementwiseOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8, 8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
|
||||
static_cast<int>(nchw[2] * nchw[3]),
|
||||
static_cast<int>(nchw[3]),
|
||||
1};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
|
||||
Tensor<ADataType>(ab_lengths, ab_strides)};
|
||||
Tensor<ADataType>& a0 = as[0];
|
||||
Tensor<ADataType>& a1 = as[1];
|
||||
Tensor<BDataType> b(ab_lengths, ab_strides);
|
||||
float alpha = 3.f;
|
||||
float beta = 2.f;
|
||||
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0.mData.data());
|
||||
a1_device_buf.ToDevice(a1.mData.data());
|
||||
|
||||
std::array<const void*, 2> inputs = {a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
|
||||
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths,
|
||||
{ab_strides, ab_strides},
|
||||
{ab_strides},
|
||||
inputs,
|
||||
output,
|
||||
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
|
||||
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
|
||||
std::cout << "B (nchw): " << b.mDesc << std::endl;
|
||||
|
||||
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<BDataType> host_b(ab_lengths, ab_strides);
|
||||
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as,
|
||||
host_b,
|
||||
BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance =
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
|
||||
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
|
||||
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
|
||||
{
|
||||
auto a_val = A_ncdhw(n, c, d, h, w);
|
||||
functor(B_ndhwc(n, d, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -51,32 +39,7 @@ int main()
|
||||
|
||||
std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8};
|
||||
std::vector<std::size_t> ndhwc = {16, 8, 8, 8, 8};
|
||||
Tensor<ADataType> a(ncdhw);
|
||||
Tensor<BDataType> b(ndhwc);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 5> ab_lengths;
|
||||
/**std::array<ck::index_t, 5> a_strides = {
|
||||
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
|
||||
static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]),
|
||||
static_cast<int>(ncdhw[3] * ncdhw[4]),
|
||||
static_cast<int>(ncdhw[4]),
|
||||
1};
|
||||
std::array<ck::index_t, 5> b_strides = {
|
||||
static_cast<int>(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]),
|
||||
static_cast<int>(ndhwc[2] * ndhwc[3] * ndhwc[4]),
|
||||
1,
|
||||
static_cast<int>(ndhwc[3] * ndhwc[4]),
|
||||
static_cast<int>(ndhwc[4])};**/
|
||||
|
||||
std::array<ck::index_t, 5> a_strides = {
|
||||
static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
|
||||
@@ -93,6 +56,20 @@ int main()
|
||||
1};
|
||||
ck::ranges::copy(ncdhw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
@@ -126,10 +103,16 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(ndhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance =
|
||||
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance =
|
||||
ck::Sequence<4>, // InScalarPerVectorSeq
|
||||
ck::Sequence<4>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
|
||||
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
|
||||
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
|
||||
{
|
||||
auto a_val = A_ncdhw(n, c, d, h, w);
|
||||
functor(B_ndhwc(n, d, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -59,10 +47,13 @@ int main()
|
||||
const int W = 5;
|
||||
const int D = 16;
|
||||
|
||||
std::vector<std::size_t> ncdhw = {N, C, D, H, W};
|
||||
std::vector<std::size_t> ndhwc = {N, D, H, W, C};
|
||||
Tensor<ADataType> a(ncdhw);
|
||||
Tensor<BDataType> b(ndhwc);
|
||||
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
|
||||
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
|
||||
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
@@ -74,10 +65,6 @@ int main()
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
|
||||
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
|
||||
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
@@ -94,11 +81,12 @@ int main()
|
||||
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
|
||||
std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] *
|
||||
ab_lengths[3] * ab_lengths[4];
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
|
||||
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
|
||||
(sizeof(ADataType) + sizeof(BDataType)) *
|
||||
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -111,10 +99,17 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(ndhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
|
||||
using ReferenceElementwiseInstance =
|
||||
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
|
||||
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
|
||||
{
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -55,18 +44,6 @@ int main()
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
|
||||
@@ -77,9 +54,22 @@ int main()
|
||||
1,
|
||||
static_cast<int>(nhwc[2] * nhwc[3]),
|
||||
static_cast<int>(nhwc[3])};
|
||||
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
@@ -111,10 +101,16 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance =
|
||||
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
@@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance =
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
const std::vector<std::size_t>& shape_nchw,
|
||||
Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
|
||||
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
|
||||
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
|
||||
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
|
||||
{
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -54,13 +40,16 @@ int main()
|
||||
const int N = 120;
|
||||
const int C = 128;
|
||||
const int H = 32;
|
||||
const int W = 1024;
|
||||
const int W = 32;
|
||||
|
||||
std::vector<std::size_t> nchw = {N, C, H, W};
|
||||
std::vector<std::size_t> nhwc = {N, H, W, C};
|
||||
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
|
||||
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
|
||||
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
@@ -72,11 +61,6 @@ int main()
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
|
||||
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
@@ -94,10 +78,11 @@ int main()
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
std::size_t flop =
|
||||
std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) *
|
||||
(ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -110,11 +95,16 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance =
|
||||
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
|
||||
host_b, a, nchw, PassThrough{});
|
||||
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@
|
||||
#include <random>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -21,11 +23,14 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
UnaryScaleSquare, // UnaryScaleSquare
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
@@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
std::size_t N = A_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = A_nchw.mDesc.GetLengths()[1];
|
||||
std::size_t H = A_nchw.mDesc.GetLengths()[2];
|
||||
std::size_t W = A_nchw.mDesc.GetLengths()[3];
|
||||
for(std::size_t w = 0; w < W; ++w)
|
||||
for(std::size_t h = 0; h < H; ++h)
|
||||
for(std::size_t c = 0; c < C; ++c)
|
||||
for(std::size_t n = 0; n < N; ++n)
|
||||
{
|
||||
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
|
||||
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -60,8 +48,21 @@ int main()
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 8, 32, 64};
|
||||
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> a_strides = {1,
|
||||
static_cast<int>(nchw[0]),
|
||||
static_cast<int>(nchw[0] * nchw[1]),
|
||||
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
|
||||
|
||||
std::array<ck::index_t, 4> b_strides = {1,
|
||||
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
|
||||
static_cast<int>(nhwc[0]),
|
||||
static_cast<int>(nhwc[0] * nhwc[1])};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
float scale = 1.f;
|
||||
auto i = 0;
|
||||
std::mt19937 gen(11939);
|
||||
@@ -84,22 +85,14 @@ int main()
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {1,
|
||||
static_cast<int>(nchw[0]),
|
||||
static_cast<int>(nchw[0] * nchw[1]),
|
||||
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
|
||||
|
||||
std::array<ck::index_t, 4> b_strides = {1,
|
||||
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
|
||||
static_cast<int>(nhwc[0]),
|
||||
static_cast<int>(nhwc[0] * nhwc[1])};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
auto argument =
|
||||
broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -113,11 +106,10 @@ int main()
|
||||
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
|
||||
std::size_t num_btype =
|
||||
(2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
@@ -129,10 +121,17 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -20,11 +22,14 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
UnaryScaleSquare, // UnaryScaleSquare
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
|
||||
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
|
||||
{
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -55,18 +47,6 @@ int main()
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
float scale = 2.f;
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
|
||||
@@ -80,9 +60,29 @@ int main()
|
||||
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
float scale = 2.f;
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
auto argument =
|
||||
broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -112,10 +112,17 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -20,11 +22,14 @@ using F32 = float;
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
UnaryScaleSquare, // UnaryScaleSquare
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
@@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
std::size_t N = A_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = A_nchw.mDesc.GetLengths()[1];
|
||||
std::size_t H = A_nchw.mDesc.GetLengths()[2];
|
||||
std::size_t W = A_nchw.mDesc.GetLengths()[3];
|
||||
for(std::size_t w = 0; w < W; ++w)
|
||||
for(std::size_t h = 0; h < H; ++h)
|
||||
for(std::size_t c = 0; c < C; ++c)
|
||||
for(std::size_t n = 0; n < N; ++n)
|
||||
{
|
||||
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
|
||||
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
std::vector<std::size_t> nchw = {5, 4, 2, 3};
|
||||
std::vector<std::size_t> nhwc = {5, 2, 3, 4};
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
std::vector<std::size_t> nchw = {16, 8, 32, 64};
|
||||
std::vector<std::size_t> nhwc = {16, 32, 64, 8};
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {1,
|
||||
static_cast<int>(nchw[0]),
|
||||
static_cast<int>(nchw[0] * nchw[1]),
|
||||
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
|
||||
|
||||
std::array<ck::index_t, 4> b_strides = {1,
|
||||
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
|
||||
static_cast<int>(nhwc[0]),
|
||||
static_cast<int>(nhwc[0] * nhwc[1])};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
|
||||
float scale = 1.f;
|
||||
auto i = 0;
|
||||
@@ -84,22 +86,14 @@ int main()
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {1,
|
||||
static_cast<int>(nchw[0]),
|
||||
static_cast<int>(nchw[0] * nchw[1]),
|
||||
static_cast<int>(nchw[0] * nchw[1] * nchw[2])};
|
||||
|
||||
std::array<ck::index_t, 4> b_strides = {1,
|
||||
static_cast<int>(nhwc[0] * nhwc[1] * nhwc[2]),
|
||||
static_cast<int>(nhwc[0]),
|
||||
static_cast<int>(nhwc[0] * nhwc[1])};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
auto argument =
|
||||
broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -129,10 +123,17 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -20,11 +22,14 @@ using F32 = float;
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
UnaryScaleSquare, // UnaryScaleSquare
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
@@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
|
||||
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
|
||||
{
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -55,18 +47,6 @@ int main()
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
float scale = 2.f;
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
|
||||
@@ -80,9 +60,28 @@ int main()
|
||||
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(ab_lengths, a_strides)};
|
||||
Tensor<ADataType>& a = as[0];
|
||||
Tensor<BDataType> b(ab_lengths, b_strides);
|
||||
float scale = 2.f;
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
auto argument =
|
||||
broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -112,10 +111,17 @@ int main()
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
Tensor<BDataType> host_b(ab_lengths, b_strides);
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
156
example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp
Normal file
156
example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::UnaryCombinedOp<UnarySquare, UnaryScale>;
|
||||
using BinaryAdd = ck::tensor_operation::element_wise::Add;
|
||||
// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2
|
||||
using TrinaryAddUnaryScaleSquare =
|
||||
ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp<BinaryAdd,
|
||||
BinaryAdd,
|
||||
UnaryScaleSquare,
|
||||
UnaryScaleSquare,
|
||||
UnaryScaleSquare>;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType, ADataType, ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
TrinaryAddUnaryScaleSquare, // ElementwiseOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
std::array<ck::index_t, 4> ab_lengths;
|
||||
std::array<ck::index_t, 4> ab_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
|
||||
static_cast<int>(nchw[2] * nchw[3]),
|
||||
static_cast<int>(nchw[3]),
|
||||
1};
|
||||
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 3> as = {Tensor<ADataType>(ab_lengths, ab_strides),
|
||||
Tensor<ADataType>(ab_lengths, ab_strides),
|
||||
Tensor<ADataType>(ab_lengths, ab_strides)};
|
||||
Tensor<ADataType>& a0 = as[0];
|
||||
Tensor<ADataType>& a1 = as[1];
|
||||
Tensor<ADataType>& a2 = as[2];
|
||||
Tensor<BDataType> b(ab_lengths, ab_strides);
|
||||
float alpha = 3.f;
|
||||
float beta = 2.f;
|
||||
float gamma = 4.f;
|
||||
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
a1.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
a2.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0.mData.data());
|
||||
a1_device_buf.ToDevice(a1.mData.data());
|
||||
a2_device_buf.ToDevice(a2.mData.data());
|
||||
|
||||
std::array<const void*, 3> inputs = {a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
a2_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}};
|
||||
auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}};
|
||||
auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths,
|
||||
{ab_strides, ab_strides, ab_strides},
|
||||
{ab_strides},
|
||||
inputs,
|
||||
output,
|
||||
TrinaryAddUnaryScaleSquare{
|
||||
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
std::cout << "A0 (nchw): " << a0.mDesc << std::endl;
|
||||
std::cout << "A1 (nchw): " << a1.mDesc << std::endl;
|
||||
std::cout << "A2 (nchw): " << a2.mDesc << std::endl;
|
||||
std::cout << "B (nchw): " << b.mDesc << std::endl;
|
||||
|
||||
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<BDataType> host_b(ab_lengths, ab_strides);
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>;
|
||||
auto ref_elementwise = ReferenceElementwiseInstance{};
|
||||
auto ref_invoker = ref_elementwise.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_elementwise.MakeArgument(
|
||||
as,
|
||||
host_b,
|
||||
TrinaryAddUnaryScaleSquare{
|
||||
BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2});
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
const double threshold = std::pow(2, -10) * 2;
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
pass &= ck::utils::check_err(
|
||||
b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -1,8 +1 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp)
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_im2col_col2im)
|
||||
add_custom_target(example_im2col_col2im)
|
||||
|
||||
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
|
||||
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
|
||||
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
|
||||
add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
|
||||
|
||||
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
|
||||
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
|
||||
add_example_dependencies(example_im2col_col2im example_column_to_image_f32)
|
||||
|
||||
@@ -1,8 +1 @@
|
||||
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
|
||||
|
||||
@@ -1,8 +1 @@
|
||||
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
|
||||
|
||||
@@ -2,16 +2,9 @@ add_subdirectory(binary)
|
||||
add_subdirectory(multi_AB)
|
||||
add_subdirectory(unary)
|
||||
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl)
|
||||
# ScaleAdd ScaleAdd Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
|
||||
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_custom_target(example_convnd_activ_xdl)
|
||||
# ScaleAdd ScaleAdd Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
|
||||
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_fpAintB_gemm_wmma)
|
||||
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
|
||||
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
|
||||
endif()
|
||||
add_custom_target(example_fpAintB_gemm_wmma)
|
||||
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
|
||||
add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
|
||||
|
||||
@@ -5,6 +5,12 @@ include_directories(BEFORE
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
|
||||
if(FILE_NAME)
|
||||
add_dependencies(EXAMPLE_NAME FILE_NAME)
|
||||
endif()
|
||||
endfunction(add_example_dependencies EXAMPLE_NAME)
|
||||
|
||||
function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
|
||||
@@ -45,6 +45,10 @@
|
||||
#endif
|
||||
|
||||
// define general macros for various architectures
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
@@ -62,8 +66,7 @@
|
||||
// buffer resource
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__)
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx103__)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
@@ -75,8 +78,7 @@
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
|
||||
#define CK_USE_AMD_V_MAC_F32
|
||||
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
|
||||
defined(__gfx94__) // for GPU code
|
||||
#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code
|
||||
#define CK_USE_AMD_V_FMAC_F32
|
||||
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||
#define CK_USE_AMD_V_DOT4_I32_I8
|
||||
@@ -89,7 +91,7 @@
|
||||
// MFMA instruction
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_MFMA
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#define CK_USE_AMD_MFMA
|
||||
#endif
|
||||
|
||||
@@ -120,7 +122,7 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
|
||||
@@ -23,6 +23,7 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
bool Zeroing,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
@@ -106,33 +107,63 @@ __global__ void
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
|
||||
|
||||
auto barrier_count_finished =
|
||||
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
auto barrier_count_finished =
|
||||
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
|
||||
GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
nullptr,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
}
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
@@ -193,8 +224,11 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeType = ADataType,
|
||||
typename ALDSType = ComputeType,
|
||||
typename BLDSType = ComputeType>
|
||||
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using AComputeType = ComputeType;
|
||||
using BComputeType = ComputeType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeType,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
LoopSched,
|
||||
PipelineVer,
|
||||
ALDSType,
|
||||
BLDSType>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMapMLoops
|
||||
@@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
if(arg.k_batch_ == 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
false,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
reinterpret_cast<uint32_t*>(arg.p_workspace_),
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
nullptr,
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
true,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
reinterpret_cast<uint32_t*>(arg.p_workspace_),
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
|
||||
// in IsSupportedArgument function
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is
|
||||
// enforced in IsSupportedArgument function
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
|
||||
bool supported = true;
|
||||
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by vector
|
||||
// load size.
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by
|
||||
// vector load size.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default)
|
||||
{
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
|
||||
// thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
|
||||
// layout, thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
|
||||
|
||||
@@ -92,6 +92,110 @@ struct Add
|
||||
};
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
const Y x0_converted = type_convert<Y>(x0);
|
||||
const Y x1_converted = type_convert<Y>(x1);
|
||||
y = ck::math::max(x0_converted, x1_converted);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
const Y x0_converted = type_convert<Y>(x0);
|
||||
const Y x1_converted = type_convert<Y>(x1);
|
||||
y = ck::math::min(x0_converted, x1_converted);
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = x0 * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 * type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0 * x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0) * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x1);
|
||||
y = x0 * x1_tmp;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x0);
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp * x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x0 * x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = x0 * x1;
|
||||
};
|
||||
};
|
||||
|
||||
struct ScaleAdd
|
||||
{
|
||||
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
// y = UnaryOp0(UnaryOp1(...(x)))
|
||||
template <typename... UnaryOpsSet>
|
||||
struct UnaryCombinedOp
|
||||
{
|
||||
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// Execute first unary op to copy data to y
|
||||
unary_ops_.At(Number<0>{})(y, x);
|
||||
|
||||
static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
|
||||
};
|
||||
|
||||
Tuple<UnaryOpsSet...> unary_ops_;
|
||||
};
|
||||
|
||||
// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
|
||||
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
|
||||
struct BinaryWithUnaryCombinedOp
|
||||
{
|
||||
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
|
||||
UnaryOp0 unary_op0,
|
||||
UnaryOp1 unary_op1)
|
||||
: binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
Y unary_x0_tmp_result;
|
||||
Y unary_x1_tmp_result;
|
||||
unary_op0_(unary_x0_tmp_result, x0);
|
||||
unary_op1_(unary_x1_tmp_result, x1);
|
||||
binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
|
||||
};
|
||||
|
||||
private:
|
||||
BinaryOp binary_op_;
|
||||
UnaryOp0 unary_op0_;
|
||||
UnaryOp1 unary_op1_;
|
||||
};
|
||||
|
||||
// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
|
||||
template <typename BinaryOp0,
|
||||
typename BinaryOp1,
|
||||
typename UnaryOp0,
|
||||
typename UnaryOp1,
|
||||
typename UnaryOp2>
|
||||
struct TrinaryWithUnaryCombinedOp
|
||||
{
|
||||
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
|
||||
BinaryOp0 binary_op1,
|
||||
UnaryOp0 unary_op0,
|
||||
UnaryOp1 unary_op1,
|
||||
UnaryOp2 unary_op2)
|
||||
: binary_op0_(binary_op0),
|
||||
binary_op1_(binary_op1),
|
||||
unary_op0_(unary_op0),
|
||||
unary_op1_(unary_op1),
|
||||
unary_op2_(unary_op2)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Y, typename X0, typename X1, typename X2>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
|
||||
{
|
||||
|
||||
Y unary_x0_tmp_result;
|
||||
Y unary_x1_tmp_result;
|
||||
Y unary_x2_tmp_result;
|
||||
unary_op0_(unary_x0_tmp_result, x0);
|
||||
unary_op1_(unary_x1_tmp_result, x1);
|
||||
unary_op2_(unary_x2_tmp_result, x2);
|
||||
binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
|
||||
binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
|
||||
};
|
||||
|
||||
private:
|
||||
BinaryOp0 binary_op0_{};
|
||||
BinaryOp1 binary_op1_{};
|
||||
UnaryOp0 unary_op0_{};
|
||||
UnaryOp1 unary_op1_{};
|
||||
UnaryOp2 unary_op2_{};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -12,10 +12,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
#if CK_WORKAROUND_SWDEV_383542
|
||||
extern "C" __device__ float __ocml_native_recip_f32(float);
|
||||
#endif
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
@@ -449,11 +445,7 @@ struct FastGelu
|
||||
const float u = x * (c1 * x * x + c2);
|
||||
const float emu = __expf(u);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_383542
|
||||
y = x * __frcp_rn(1.f + emu);
|
||||
#else
|
||||
y = x * __ocml_native_recip_f32(1.f + emu);
|
||||
#endif
|
||||
y = x * ck::math::rcp(1.f + emu);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -559,6 +551,244 @@ struct TanH
|
||||
};
|
||||
};
|
||||
|
||||
struct ACos
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::acos(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Neg
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::neg(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ATan
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::atan(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Sin
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::sin(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ASinH
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::asinh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Cos
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::cos(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ACosH
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::acosh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Tan
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::tan(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ATanH
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::atanh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct SinH
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::sinh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Ceil
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::ceil(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Exp
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::exp(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct CosH
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::cosh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::floor(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::log(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ASin
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::asin(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Rcp
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::rcp(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Swish
|
||||
{
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
@@ -118,8 +118,16 @@ struct GridwiseElementwise
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
|
||||
const index_t m1_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
|
||||
const auto thread_grid_offset =
|
||||
make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
|
||||
const auto input_thread_grid_offset = generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
const auto output_thread_grid_offset = generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
// If src and dst have same vector dim, then:
|
||||
@@ -157,9 +165,9 @@ struct GridwiseElementwise
|
||||
uniform_sequence_gen_t<NumOutput, 1>,
|
||||
uniform_sequence_gen_t<NumInput, false>,
|
||||
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
|
||||
thread_grid_offset,
|
||||
input_thread_grid_offset,
|
||||
out_grid_desc_tuple,
|
||||
thread_grid_offset,
|
||||
output_thread_grid_offset,
|
||||
elementwise_op};
|
||||
global_to_global_transfer.Run(
|
||||
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
|
||||
|
||||
@@ -31,7 +31,8 @@ namespace ck {
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeType,
|
||||
typename AComputeType,
|
||||
typename BComputeType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -71,7 +72,9 @@ template <typename ADataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer,
|
||||
typename ALDSType,
|
||||
typename BLDSType>
|
||||
struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ComputeType),
|
||||
return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
|
||||
b_block_space_size_aligned * sizeof(BLDSType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t NumDTensor_,
|
||||
typename DsDataType_,
|
||||
bool Zeroing,
|
||||
typename AGridDesc_KBatch_AK0_M_AK1,
|
||||
typename BGridDesc_KBatch_BK0_N_BK1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ComputeType,
|
||||
ALDSType,
|
||||
decltype(a_grid_desc_kbatch_ak0_m_ak1),
|
||||
decltype(a_block_desc_kbatch_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
ComputeType,
|
||||
BLDSType,
|
||||
decltype(b_grid_desc_kbatch_bk0_n_bk1),
|
||||
decltype(b_block_desc_kbatch_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeType,
|
||||
ComputeType,
|
||||
ALDSType,
|
||||
BLDSType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
@@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeType,
|
||||
BComputeType>();
|
||||
|
||||
#if 1
|
||||
if(block_work_idx[I0] == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
const index_t numNThreads = NPerBlock / nThreadSize;
|
||||
const index_t numMThreads = BlockSize / numNThreads;
|
||||
const index_t mThreadSize = MPerBlock / numMThreads;
|
||||
|
||||
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
|
||||
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
|
||||
|
||||
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
EDataType,
|
||||
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
|
||||
true>
|
||||
e_thread_zero_buf;
|
||||
|
||||
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
EDataType,
|
||||
EDataType,
|
||||
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mThreadSize, 1, nThreadSize>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1],
|
||||
m_tid * mThreadSize,
|
||||
block_work_idx[I2],
|
||||
n_tid * nThreadSize),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
e_thread_zero_buf,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
if(block_work_idx[I0] == 0)
|
||||
{
|
||||
atomicAdd(barrier_count_finished, 1);
|
||||
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
const index_t numNThreads = NPerBlock / nThreadSize;
|
||||
const index_t numMThreads = BlockSize / numNThreads;
|
||||
const index_t mThreadSize = MPerBlock / numMThreads;
|
||||
|
||||
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
|
||||
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
|
||||
|
||||
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
EDataType,
|
||||
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
|
||||
true>
|
||||
e_thread_zero_buf;
|
||||
|
||||
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
EDataType,
|
||||
EDataType,
|
||||
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mThreadSize, 1, nThreadSize>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1],
|
||||
m_tid * mThreadSize,
|
||||
block_work_idx[I2],
|
||||
n_tid * nThreadSize),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
e_thread_zero_buf,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_buf);
|
||||
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(barrier_count_finished, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
|
||||
@@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
@@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
}
|
||||
});
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
|
||||
|
||||
if(k_id_finished_t == KBatch)
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
*barrier_count_finished = 0;
|
||||
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
|
||||
|
||||
if(k_id_finished_t == KBatch)
|
||||
{
|
||||
*barrier_count_finished = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
|
||||
const void* __restrict__ p_b_grid_,
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
uint32_t* barrier_count_finished,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideE,
|
||||
const index_t KBatch,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
|
||||
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
|
||||
|
||||
using DsGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
const auto a_grid_desc_kbatch_ak0_m_ak1 =
|
||||
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
|
||||
|
||||
const auto b_grid_desc_kbatch_bk0_n_bk1 =
|
||||
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
@@ -976,7 +1098,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
uint32_t* barrier_count_finished,
|
||||
uint32_t*,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
@@ -1028,49 +1150,22 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
nullptr,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,6 +14,10 @@
|
||||
namespace ck {
|
||||
namespace math {
|
||||
|
||||
#if CK_WORKAROUND_SWDEV_383542
|
||||
extern "C" __device__ float __ocml_native_recip_f32(float);
|
||||
#endif
|
||||
|
||||
// math functions for the host, some are implemented by calling C++ std functions
|
||||
|
||||
static inline __host__ float abs(float x) { return std::abs(x); };
|
||||
@@ -111,6 +115,276 @@ inline __host__ double tanh<double>(double x)
|
||||
return std::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T acos(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::acosf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float acos<float>(float x)
|
||||
{
|
||||
return std::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double acos<double>(double x)
|
||||
{
|
||||
return std::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T neg(T x)
|
||||
{
|
||||
return ck::type_convert<T>(-(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T atan(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::atanf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float atan<float>(float x)
|
||||
{
|
||||
return std::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double atan<double>(double x)
|
||||
{
|
||||
return std::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T sin(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::sinf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float sin<float>(float x)
|
||||
{
|
||||
return std::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double sin<double>(double x)
|
||||
{
|
||||
return std::sin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T asin(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::asinf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float asin<float>(float x)
|
||||
{
|
||||
return std::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double asin<double>(double x)
|
||||
{
|
||||
return std::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T asinh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::asinhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float asinh<float>(float x)
|
||||
{
|
||||
return std::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double asinh<double>(double x)
|
||||
{
|
||||
return std::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T cos(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::cosf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float cos<float>(float x)
|
||||
{
|
||||
return std::cosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double cos<double>(double x)
|
||||
{
|
||||
return std::cos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T acosh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::acoshf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float acosh<float>(float x)
|
||||
{
|
||||
return std::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double acosh<double>(double x)
|
||||
{
|
||||
return std::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T tan(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::tanf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float tan<float>(float x)
|
||||
{
|
||||
return std::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double tan<double>(double x)
|
||||
{
|
||||
return std::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T atanh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::atanhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float atanh<float>(float x)
|
||||
{
|
||||
return std::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double atanh<double>(double x)
|
||||
{
|
||||
return std::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T sinh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::sinhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float sinh<float>(float x)
|
||||
{
|
||||
return std::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double sinh<double>(double x)
|
||||
{
|
||||
return std::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T ceil(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::ceilf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float ceil<float>(float x)
|
||||
{
|
||||
return std::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double ceil<double>(double x)
|
||||
{
|
||||
return std::ceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T cosh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::coshf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float cosh<float>(float x)
|
||||
{
|
||||
return std::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double cosh<double>(double x)
|
||||
{
|
||||
return std::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T floor(T x)
|
||||
{
|
||||
return ck::type_convert<T>(std::floorf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ float floor<float>(float x)
|
||||
{
|
||||
return std::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __host__ double floor<double>(double x)
|
||||
{
|
||||
return std::floor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T rcp(T x)
|
||||
{
|
||||
return ck::type_convert<T>(1.f / ck::type_convert<float>(x));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __host__ T exp(T x)
|
||||
{
|
||||
@@ -282,6 +556,286 @@ inline __device__ double tanh<double>(double x)
|
||||
return ::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T acos(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::acosf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float acos<float>(float x)
|
||||
{
|
||||
return ::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double acos<double>(double x)
|
||||
{
|
||||
return ::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T neg(T x)
|
||||
{
|
||||
return ck::type_convert<T>(-(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ half_t neg<half_t>(half_t x)
|
||||
{
|
||||
return __hneg(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T atan(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::atanf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float atan<float>(float x)
|
||||
{
|
||||
return ::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double atan<double>(double x)
|
||||
{
|
||||
return ::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T sin(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::sinf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float sin<float>(float x)
|
||||
{
|
||||
return ::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double sin<double>(double x)
|
||||
{
|
||||
return ::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ half_t sin<half_t>(half_t x)
|
||||
{
|
||||
return ::hsin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T asin(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::asinf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float asin<float>(float x)
|
||||
{
|
||||
return ::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double asin<double>(double x)
|
||||
{
|
||||
return ::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T asinh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::asinhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float asinh<float>(float x)
|
||||
{
|
||||
return ::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double asinh<double>(double x)
|
||||
{
|
||||
return ::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T acosh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::acoshf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float acosh<float>(float x)
|
||||
{
|
||||
return ::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double acosh<double>(double x)
|
||||
{
|
||||
return ::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T tan(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::tanf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float tan<float>(float x)
|
||||
{
|
||||
return ::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double tan<double>(double x)
|
||||
{
|
||||
return ::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T atanh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::atanhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float atanh<float>(float x)
|
||||
{
|
||||
return ::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double atanh<double>(double x)
|
||||
{
|
||||
return ::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T sinh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::sinhf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float sinh<float>(float x)
|
||||
{
|
||||
return ::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double sinh<double>(double x)
|
||||
{
|
||||
return ::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T ceil(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::ceilf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float ceil<float>(float x)
|
||||
{
|
||||
return ::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double ceil<double>(double x)
|
||||
{
|
||||
return ::ceil(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ half_t ceil<half_t>(half_t x)
|
||||
{
|
||||
return ::hceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T cosh(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::coshf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float cosh<float>(float x)
|
||||
{
|
||||
return ::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double cosh<double>(double x)
|
||||
{
|
||||
return ::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T floor(T x)
|
||||
{
|
||||
return ck::type_convert<T>(::floorf(ck::type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ float floor<float>(float x)
|
||||
{
|
||||
return ::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ double floor<double>(double x)
|
||||
{
|
||||
return ::floor(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ half_t floor<half_t>(half_t x)
|
||||
{
|
||||
return ::hfloor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T rcp(T x)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_383542
|
||||
return __frcp_rn(x);
|
||||
#else
|
||||
return __ocml_native_recip_f32(x);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T exp(T x)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <index_t NumATensors, typename ADataType, typename BDataType, typename ElementOp>
|
||||
struct ReferenceElementwise : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
|
||||
Tensor<BDataType>& b_tensor,
|
||||
ElementOp element_op)
|
||||
: a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const std::array<Tensor<ADataType>, NumATensors>& a_tensors_;
|
||||
Tensor<BDataType>& b_tensor_;
|
||||
ElementOp element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceElementwise::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if constexpr(NumATensors == 1)
|
||||
{
|
||||
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
|
||||
arg.element_op_(self(idx), arg.a_tensors_[0](idx));
|
||||
});
|
||||
}
|
||||
else if constexpr(NumATensors == 2)
|
||||
{
|
||||
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
|
||||
arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx));
|
||||
});
|
||||
}
|
||||
else if constexpr(NumATensors == 3)
|
||||
{
|
||||
arg.b_tensor_.ForEach([&](auto& self, auto idx) {
|
||||
arg.element_op_(self(idx),
|
||||
arg.a_tensors_[0](idx),
|
||||
arg.a_tensors_[1](idx),
|
||||
arg.a_tensors_[2](idx));
|
||||
});
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const std::array<Tensor<ADataType>, NumATensors>& a_tensors,
|
||||
Tensor<BDataType>& b_tensor,
|
||||
ElementOp element_op)
|
||||
{
|
||||
return Argument{a_tensors, b_tensor, element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceElementwise"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -12,398 +12,21 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
#include "gemm_dl.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "gemm_wmma.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
#include "gemm_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
|
||||
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -435,16 +58,137 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, int8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // DL_KERNELS
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
@@ -69,7 +69,7 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
|
||||
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP32)
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,238 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
|
||||
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,439 +10,18 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_backward_data_xdl.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "grouped_convolution_backward_data_wmma.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
@@ -488,9 +67,10 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
@@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
@@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
@@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
|
||||
is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
op_ptrs);
|
||||
@@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
else if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
|
||||
is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
op_ptrs);
|
||||
@@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,216 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,480 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,357 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv1d forward, GNWC/GKXC/GNWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
BF8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
# Do not build DL instances if DL_KERNELS macro is not set
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build XDL instances if gfx9 targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build WMMA instances if gfx11 targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
|
||||
@@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list})
|
||||
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
|
||||
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if((add_inst EQUAL 1))
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(BATCHED_GEMM_INSTANCES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_batched_gemm_bias_permute_instance
|
||||
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_batched_gemm_gemm_instance
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_batched_gemm_reduce_instance
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_instance
|
||||
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES
|
||||
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
|
||||
|
||||
# FP32
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(DEVICE_CONTRACTION_SCALE_INSTANCES)
|
||||
|
||||
# FP32
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_conv1d_bwd_data_instance
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
|
||||
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_AND_DL_KERNELS
|
||||
set(CONV2D_BWD_DATA_INSTANCES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(DEVICE_CONV2D_FWD_INSTANCES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_conv2d_fwd_bias_relu_instance
|
||||
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_conv2d_fwd_bias_relu_add_instance
|
||||
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_conv3d_bwd_data_instance
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp
|
||||
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_gemm_add_instance
|
||||
device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user