diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index d64297cb35..930cdefb7e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -35,7 +35,7 @@ constexpr ck_tile::index_t get_k_warp_tile() template constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; else diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 87772f78fc..f83bbc2a18 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -301,6 +301,30 @@ using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 5021fb9907..1d3dd2ae6f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -112,6 +112,10 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp index cf9bf18c5a..61bb1a8bdd 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32)