mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 5 (#2725)
* Enable xdl in gfx11 & gfx12 * update cmake file * fix all instance build (cmake) * fix batched_gemm_gemm(cmake) * rebase cmake files * fix cmake build error * remve CK_ENABLE_DYNAMIC_WARP_SIZE * update cmake build error2 * fix gfx11 build CK_USE_XDL is enabled on gfx11 and gfx12 * fix gfx10 build * fix gfx11 error --------- Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
This commit is contained in:
@@ -153,7 +153,7 @@ list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
@@ -173,11 +173,13 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
|
||||
|
||||
@@ -92,7 +92,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -166,8 +167,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
#if defined(CK_USE_XDL) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(
|
||||
|
||||
@@ -103,7 +103,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
#endif
|
||||
@@ -167,7 +168,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
@@ -201,7 +203,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
|
||||
|
||||
@@ -104,7 +104,8 @@ int profile_gemm_universal_preshuffle(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
@@ -163,7 +164,8 @@ int profile_gemm_universal_preshuffle(int argc, char* argv[])
|
||||
{
|
||||
return profile(F8{}, F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || CK_USE_OCP_FP8 || defined(CK_USE_GFX94) || \
|
||||
defined(CK_USE_WMMA_FP8)
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
|
||||
Reference in New Issue
Block a user