mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 5 (#2725)
* Enable xdl in gfx11 & gfx12
* update cmake file
* fix all instance build (cmake)
* fix batched_gemm_gemm(cmake)
* rebase cmake files
* fix cmake build error
* remve CK_ENABLE_DYNAMIC_WARP_SIZE
* update cmake build error2
* fix gfx11 build
CK_USE_XDL is enabled on gfx11 and gfx12
* fix gfx10 build
* fix gfx11 error
---------
Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
[ROCm/composable_kernel commit: f22740df82]
This commit is contained in:
@@ -220,7 +220,7 @@ rocm_check_target_ids(SUPPORTED_GPU_TARGETS
|
||||
|
||||
message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
message(STATUS "Enabling XDL instances")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
|
||||
@@ -48,7 +48,7 @@ else()
|
||||
endif()
|
||||
|
||||
if (GPU_TARGETS)
|
||||
if (GPU_TARGETS MATCHES "gfx9")
|
||||
if (GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
|
||||
@@ -44,8 +44,7 @@ list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllv
|
||||
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
@@ -89,7 +88,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3)
|
||||
|
||||
|
||||
@@ -1,8 +1 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp)
|
||||
|
||||
@@ -53,7 +53,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device::
|
||||
//############################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################################| | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
DeviceGroupedContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
// clang-format on
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
|
||||
@@ -11,6 +11,6 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95")
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_binary_xdl)
|
||||
# Bilinear residual
|
||||
add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16)
|
||||
add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16)
|
||||
add_example_executable(example_convnd_bwd_weight_xdl_bilinear_residual_fp16 convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_weight_xdl_bilinear_residual_fp16)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_custom_target(example_convnd_activ_binary_xdl)
|
||||
# Bilinear residual
|
||||
add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16)
|
||||
add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16)
|
||||
add_example_executable(example_convnd_bwd_weight_xdl_bilinear_residual_fp16 convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_weight_xdl_bilinear_residual_fp16)
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl_convinvscale)
|
||||
add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convinvscale)
|
||||
add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8)
|
||||
endif()
|
||||
@@ -1,20 +1,14 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl_convscale)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 )
|
||||
if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convscale)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 )
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8)
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_add)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8 )
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_add)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8)
|
||||
endif()
|
||||
@@ -1,14 +1,8 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_reduce)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8)
|
||||
if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_reduce)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8)
|
||||
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
|
||||
endif()
|
||||
@@ -1,11 +1,5 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_relu)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8 )
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if (NOT GPU_TARGETS MATCHES "gfx11")
|
||||
add_custom_target(example_convnd_activ_xdl_convscale_relu)
|
||||
add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8)
|
||||
endif()
|
||||
|
||||
@@ -1,45 +1,37 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_dynamic_unary_xdl)
|
||||
# Sigmoid
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16)
|
||||
# Tanh
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16)
|
||||
# Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16)
|
||||
# SoftRelu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16)
|
||||
# Abs
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16)
|
||||
# Pow
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16)
|
||||
# Clipped Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16)
|
||||
# Leaky Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16)
|
||||
# Elu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16)
|
||||
# Swish
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16)
|
||||
# PassThrough
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16)
|
||||
# Logistic
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_custom_target(example_convnd_activ_dynamic_unary_xdl)
|
||||
# Sigmoid
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16)
|
||||
# Tanh
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16)
|
||||
# Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16)
|
||||
# SoftRelu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16)
|
||||
# Abs
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16)
|
||||
# Pow
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16)
|
||||
# Clipped Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16)
|
||||
# Leaky Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16)
|
||||
# Elu
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16)
|
||||
# Swish
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16)
|
||||
# PassThrough
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16)
|
||||
# Logistic
|
||||
add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
|
||||
@@ -1,17 +1,10 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_multi_ab_xdl)
|
||||
# ScaleAdd on A and B
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp16)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 conv_fwd_xdl_scaleadd_ab_fp32.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp32)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scaleadd_ab_bf16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_custom_target(example_convnd_activ_multi_ab_xdl)
|
||||
# ScaleAdd on A and B
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp16)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 conv_fwd_xdl_scaleadd_ab_fp32.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp32)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scaleadd_ab_bf16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
|
||||
|
||||
@@ -1,45 +1,37 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_activ_unary_xdl)
|
||||
# Sigmoid
|
||||
add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16)
|
||||
# Tanh
|
||||
add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16)
|
||||
# Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16)
|
||||
# SoftRelu
|
||||
add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16)
|
||||
# Abs
|
||||
add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16)
|
||||
# Pow
|
||||
add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16)
|
||||
# Clipped Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
|
||||
# Leaky Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
|
||||
# Elu
|
||||
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16)
|
||||
# Swish
|
||||
add_example_executable(example_convnd_fwd_xdl_swish_fp16 convnd_fwd_xdl_swish_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_swish_fp16)
|
||||
# PassThrough
|
||||
add_example_executable(example_convnd_fwd_xdl_passthrough_fp16 convnd_fwd_xdl_passthrough_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_passthrough_fp16)
|
||||
# Logistic
|
||||
add_example_executable(example_convnd_fwd_xdl_logistic_fp16 convnd_fwd_xdl_logistic_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_logistic_fp16)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_custom_target(example_convnd_activ_unary_xdl)
|
||||
# Sigmoid
|
||||
add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16)
|
||||
# Tanh
|
||||
add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16)
|
||||
# Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16)
|
||||
# SoftRelu
|
||||
add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16)
|
||||
# Abs
|
||||
add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16)
|
||||
# Pow
|
||||
add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16)
|
||||
# Clipped Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
|
||||
# Leaky Relu
|
||||
add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
|
||||
# Elu
|
||||
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16)
|
||||
# Swish
|
||||
add_example_executable(example_convnd_fwd_xdl_swish_fp16 convnd_fwd_xdl_swish_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_swish_fp16)
|
||||
# PassThrough
|
||||
add_example_executable(example_convnd_fwd_xdl_passthrough_fp16 convnd_fwd_xdl_passthrough_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_passthrough_fp16)
|
||||
# Logistic
|
||||
add_example_executable(example_convnd_fwd_xdl_logistic_fp16 convnd_fwd_xdl_logistic_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_logistic_fp16)
|
||||
|
||||
@@ -16,7 +16,7 @@ add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx11-generic gfx12-generic)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -69,7 +69,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
@@ -93,8 +93,8 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
message(DEBUG "removing bf8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95 and gfx12
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND NOT EX_TARGETS MATCHES "gfx12")
|
||||
if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
@@ -109,14 +109,14 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
if(FILE_NAME)
|
||||
if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic)
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 gfx950 and rdna3/4
|
||||
message(DEBUG "trimming targets for ${FILE_NAME}")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx10-3-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -192,7 +192,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
@@ -206,7 +206,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(source_name_list MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic)
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
|
||||
@@ -68,11 +68,8 @@ inline bool is_gfx11_supported()
|
||||
inline bool is_xdl_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
|| is_gfx12_supported() || is_gfx11_supported()
|
||||
#endif
|
||||
;
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
@@ -83,7 +80,6 @@ inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
@@ -96,7 +92,6 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
__device__ constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -38,16 +37,6 @@ inline __host__ index_t get_warp_size()
|
||||
#endif
|
||||
return 64;
|
||||
}
|
||||
#else
|
||||
__host__ __device__ constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
|
||||
@@ -359,7 +359,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
@@ -369,6 +369,8 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
@@ -381,6 +383,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, x.template get_as<float>()[I2]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, x.template get_as<float>()[I3]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
|
||||
@@ -638,7 +638,7 @@ struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP8
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94)
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
|
||||
@@ -54,7 +54,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build XDL instances if gfx9 targets are not on the target list
|
||||
if(NOT INST_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -73,13 +73,13 @@ function(add_instance_library INSTANCE_NAME)
|
||||
message(DEBUG "removing mha instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
|
||||
# Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94, gfx95 and gfx12
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
|
||||
message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_")
|
||||
message(DEBUG "removing gemm_universal_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -89,8 +89,8 @@ function(add_instance_library INSTANCE_NAME)
|
||||
message(DEBUG "removing gemm_universal_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build gemm_universal_preshuffle_f8 for any targets except gfx94
|
||||
if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950") AND (source_name MATCHES "gemm_universal_preshuffle" OR source_name MATCHES "gemm_xdl_universal_preshuffle") AND (source_name MATCHES "_f8_f8_f16" OR source_name MATCHES "_f8_f8_bf16"))
|
||||
# Do not build gemm_universal_preshuffle_f8 for any targets except gfx94, gfx95 and gfx12
|
||||
if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_universal_preshuffle" OR source_name MATCHES "gemm_xdl_universal_preshuffle") AND (source_name MATCHES "_f8_f8_f16" OR source_name MATCHES "_f8_f8_bf16"))
|
||||
message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -106,7 +106,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
if(source_name MATCHES "_xdl")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic)
|
||||
elseif(source_name MATCHES "_wmma")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(source_name MATCHES "mha")
|
||||
@@ -120,29 +120,29 @@ function(add_instance_library INSTANCE_NAME)
|
||||
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_xdl_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
else()
|
||||
if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_xdl_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "f8")
|
||||
@@ -266,7 +266,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
|
||||
message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
@@ -278,33 +278,36 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
|
||||
message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
|
||||
message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") AND (NOT DEFINED DL_KERNELS))
|
||||
message(DEBUG "Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
if("${cmake_instance}" MATCHES "gemm_multiply_multiply_wp" AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12"))
|
||||
message(DEBUG "Found gemm_multiply_multiply_wp instances, but gfx94/gfx95/gfx12 not on the target list. Skipping. ${cmake_instance}")
|
||||
set(add_inst 0)
|
||||
elseif("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12"))
|
||||
message(DEBUG "Found gemm_multiply_multiply instances, but gfx94/gfx95/gfx11/gfx12 not on the target list. Skipping. ${cmake_instance}")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
message(DEBUG "Found gemm_universal_preshuffle_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "gemm_xdl_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
if(("${cmake_instance}" MATCHES "gemm_xdl_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
message(DEBUG "Found gemm_xdl_universal_preshuffle_f8_f8_bf16 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if ("${cmake_instance}" MATCHES "gemm_bilinear")
|
||||
set(add_inst 0)
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9") AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES))
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES))
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if((SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
|
||||
|
||||
@@ -153,7 +153,7 @@ list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
@@ -173,11 +173,13 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
|
||||
@@ -92,7 +92,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -166,8 +167,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(
|
||||
|
||||
@@ -103,7 +103,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
#endif
|
||||
@@ -167,7 +168,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
@@ -201,7 +203,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
|
||||
|
||||
@@ -104,7 +104,8 @@ int profile_gemm_universal_preshuffle(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
@@ -163,7 +164,8 @@ int profile_gemm_universal_preshuffle(int argc, char* argv[])
|
||||
{
|
||||
return profile(F8{}, F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
|
||||
@@ -94,7 +94,7 @@ function(add_test_executable TEST_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "xdl")
|
||||
message(DEBUG "removing xdl test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -108,7 +108,7 @@ function(add_test_executable TEST_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
@@ -179,7 +179,7 @@ function(add_gtest_executable TEST_NAME)
|
||||
endforeach()
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND NOT TEST_TARGETS MATCHES "gfx1[12]" AND source MATCHES "xdl")
|
||||
message(DEBUG "removing xdl test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -202,7 +202,7 @@ function(add_gtest_executable TEST_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
|
||||
127
test/ck_tile/atomic_add_op/test_atomic.cpp
Executable file → Normal file
127
test/ck_tile/atomic_add_op/test_atomic.cpp
Executable file → Normal file
@@ -21,44 +21,22 @@ struct AtomicKernelParam
|
||||
template <typename DataType_, ck_tile::index_t multiple_>
|
||||
class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
{
|
||||
struct AtomicKernelWaveSize64
|
||||
struct AtomicKernelWaveSize
|
||||
{
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 8>;
|
||||
using WaveTile = ck_tile::sequence<64, 8>;
|
||||
static constexpr ck_tile::index_t kBlockSize = 128; // 2 waves * 64 lanes
|
||||
};
|
||||
|
||||
struct AtomicKernelWaveSize32
|
||||
{
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<64, 8>;
|
||||
using WaveTile = ck_tile::sequence<32, 8>; // 32*2 == 64
|
||||
static constexpr ck_tile::index_t kBlockSize = 64; // 2 waves * 32 lanes
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 8>;
|
||||
using WaveTile = ck_tile::sequence<64, 8>;
|
||||
};
|
||||
|
||||
template <typename Config>
|
||||
void RunTestImpl_(const AtomicKernelParam& params, int require_warp_size, const char* tag)
|
||||
void RunTestImpl_(const AtomicKernelParam& params)
|
||||
{
|
||||
// Device capability check & skip if wavesize mismatches
|
||||
int dev = 0;
|
||||
hipDeviceProp_t prop{};
|
||||
if(hipGetDevice(&dev) != hipSuccess || hipGetDeviceProperties(&prop, dev) != hipSuccess)
|
||||
{
|
||||
GTEST_SKIP() << "[" << tag << "] hipGetDeviceProperties failed; skipping.";
|
||||
}
|
||||
if(prop.warpSize != require_warp_size)
|
||||
{
|
||||
GTEST_SKIP() << "[" << tag << "] Device warpSize=" << prop.warpSize << " (requires "
|
||||
<< require_warp_size << "); skipping.";
|
||||
}
|
||||
|
||||
using XDataType = DataType_;
|
||||
|
||||
const ck_tile::index_t m = params.m;
|
||||
const ck_tile::index_t n = params.n;
|
||||
|
||||
std::cout << "[" << tag << "] Input Tensor Dimensions: " << m << ", " << n << std::endl;
|
||||
std::cout << "Input Tensor Dimensions: " << m << ", " << n << std::endl;
|
||||
|
||||
constexpr int dword_bytes = 4;
|
||||
const int base_vec = dword_bytes / static_cast<int>(sizeof(XDataType));
|
||||
@@ -90,7 +68,7 @@ class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
WaveTile::at(ck_tile::number<1>{}) * BlockWaves::at(ck_tile::number<1>{}),
|
||||
"BlockTile.N must equal WaveTile.N * BlockWaves.N");
|
||||
|
||||
std::cout << "[" << tag << "] Vector per thread = " << vec
|
||||
std::cout << "Vector per thread = " << vec
|
||||
<< " BlockWaves=" << BlockWaves::at(ck_tile::number<0>{}) << "x"
|
||||
<< BlockWaves::at(ck_tile::number<1>{})
|
||||
<< " WaveTile=" << WaveTile::at(ck_tile::number<0>{}) << "x"
|
||||
@@ -105,7 +83,7 @@ class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
using Problem = ck_tile::AtomicKernelProblem<XDataType, Shape>;
|
||||
using Kernel = ck_tile::AtomicKernel<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = Config::kBlockSize;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
(void)hipGetLastError(); // clear sticky
|
||||
@@ -121,9 +99,8 @@ class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
n));
|
||||
|
||||
ASSERT_EQ(hipPeekAtLastError(), hipSuccess)
|
||||
<< "[" << tag << "] hipPeekAtLastError: " << hipGetErrorString(hipGetLastError());
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess)
|
||||
<< "[" << tag << "] hipDeviceSynchronize failed";
|
||||
<< "hipPeekAtLastError: " << hipGetErrorString(hipGetLastError());
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess) << "hipDeviceSynchronize failed";
|
||||
|
||||
// host reference computation
|
||||
x_dev_input.FromDevice(x_host_dev.mData.data());
|
||||
@@ -136,17 +113,7 @@ class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
}
|
||||
|
||||
protected:
|
||||
// WaveSize = 64 path
|
||||
void RunTest(const AtomicKernelParam& params)
|
||||
{
|
||||
RunTestImpl_<AtomicKernelWaveSize64>(params, /*require_warp_size=*/64, "WS64");
|
||||
}
|
||||
|
||||
// WaveSize = 32 path
|
||||
void RunTestWave32(const AtomicKernelParam& params)
|
||||
{
|
||||
RunTestImpl_<AtomicKernelWaveSize32>(params, /*require_warp_size=*/32, "WS32");
|
||||
}
|
||||
void RunTest(const AtomicKernelParam& params) { RunTestImpl_<AtomicKernelWaveSize>(params); }
|
||||
};
|
||||
|
||||
class TestAtomicKernelHalf_1 : public TestAtomicKernel<ck_tile::half_t, 1>
|
||||
@@ -189,10 +156,6 @@ class TestAtomicKernelFloat_4 : public TestAtomicKernel<float, 4>
|
||||
{
|
||||
};
|
||||
|
||||
//
|
||||
// WaveSize=64 tests (auto-skip on wave32 devices)
|
||||
//
|
||||
#if defined(CK_USE_XDL)
|
||||
TEST_P(TestAtomicKernelHalf_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
@@ -259,72 +222,6 @@ TEST_P(TestAtomicKernelFloat_4, TestCorrectness)
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
|
||||
//
|
||||
// WaveSize=32 tests (auto-skip on wave64 devices)
|
||||
//
|
||||
#else
|
||||
TEST_P(TestAtomicKernelHalf_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_4, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_4, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
#endif
|
||||
|
||||
// Common parameter lists
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelHalf_1,
|
||||
@@ -398,10 +295,8 @@ INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFloat_4,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
#endif
|
||||
|
||||
@@ -26,14 +26,21 @@ struct AtomicKernelShape
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = MWarps;
|
||||
static constexpr index_t WarpPerBlock_N = NWarps;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
static constexpr index_t RepeatInWarp =
|
||||
Warp_M * Warp_N / Vector_M / Vector_N / ck_tile::get_warp_size();
|
||||
static constexpr index_t RepeatInWarp_M =
|
||||
(Warp_M / Vector_M > Warp_N / Vector_N) ? RepeatInWarp : 1;
|
||||
static constexpr index_t RepeatInWarp_N =
|
||||
(Warp_M / Vector_M > Warp_N / Vector_N) ? 1 : RepeatInWarp;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M / RepeatInWarp_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N / RepeatInWarp_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{});
|
||||
|
||||
@@ -54,7 +61,10 @@ struct AtomicKernel
|
||||
using XDataType = typename Problem::XDataType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return ck_tile::is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeTileDistribution()
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@ add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_da
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
|
||||
elseif(DL_KERNELS)
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp)
|
||||
if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
|
||||
else()
|
||||
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
|
||||
endif()
|
||||
endif()
|
||||
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_fwd_large_cases_xdl.cpp)
|
||||
target_compile_options(test_grouped_convnd_fwd_large_cases_xdl PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_fwd_large_cases_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx12")
|
||||
#Fail on gfx11 CI but fail to reproduce it in local, disable it temporary
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
|
||||
@@ -13,7 +16,6 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance)
|
||||
|
||||
add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp)
|
||||
target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
|
||||
Reference in New Issue
Block a user