diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bb69300a6..b28a6d9127 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") message("Enabling XDL instances") add_definitions(-DCK_USE_XDL) - set(CK_USE_XDL "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") + message("Enabling FP8 gemms in ckProfiler") + add_definitions(-DCK_USE_GFX94) endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) - set(CK_USE_WMMA "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index 576bd009b6..990cbd292e 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using F8 = ck::f8_t; #endif @@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); @@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp index 23b5c74ddd..b872d7089a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) std::tuple< F16, F8, F16, F16>, std::tuple< F8, F16, F16, F16>, std::tuple< F8, F8, F8, BF16>, @@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) std::tuple< F16, F8, F16, F16>, std::tuple< F8, F16, F16, F16>, std::tuple< F8, F8, F8, BF16>,