diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 320f801e5f..4ea27046ab 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -91,6 +91,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16_gfx12 = using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16_gfx12 = WarpGemmAttributeWmmaImpl>; +using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8_gfx12 = + WarpGemmAttributeWmmaImpl>; + using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8_gfx12 = WarpGemmAttributeWmmaImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 9ca6d29cbd..81ff5af2fe 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -30,6 +30,31 @@ struct WmmaTraits } }; +// int8 specialization - GFX12 (iu8 WMMA builtin under __gfx12__; otherwise returns CVecType{0}) +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + template + CK_TILE_DEVICE static CVecType + wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) + { +#ifdef __gfx12__ + return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a (presumably the "A is signed" flag for iu8 - TODO confirm) + bit_cast(a_vec), + true, // neg_b (presumably the "B is signed" flag for iu8 - TODO confirm) + bit_cast(b_vec), + bit_cast(c_vec), + clamp); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + // fp8/bf8 specialization - GFX12 template <> struct WmmaTraits diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 59a53de11b..84160752dd 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -39,15 +39,12 @@ template struct 
WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmWmma_f32_16x16x16_f16_f16_gfx12;}; #else template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; #endif template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; -//TODO: currently int8 in this location; need to move -template struct WarpGemmDispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11;}; - // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; @@ -111,9 +108,15 @@ template struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; -template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; + +#if defined(__gfx11__) +template struct WarpGemmDispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11;}; +#else // __gfx12__ +template struct WarpGemmDispatcher { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx12;}; +#endif // clang-format on } // namespace impl diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp 
b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp index e4be5373a6..cdd97ecf79 100644 --- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -26,6 +26,10 @@ template using WarpGemmWmma_f32_16x16x16_bf16_bf16_gfx12 = WarpGemmImpl< WarpGemmAtrributeWmma>; +template +using WarpGemmWmma_i32_16x16x16_i8_i8_gfx12 = WarpGemmImpl< + WarpGemmAtrributeWmma>; + template using WarpGemmWmma_f32_16x16x16_f8_f8_gfx12 = WarpGemmImpl< WarpGemmAtrributeWmma>; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index b38cf4691e..731fe290e1 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -58,6 +58,8 @@ using KernelTypesMemWmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>, std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>, std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>, + std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>, + std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>, std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>, std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, 
ck_tile::number<16>, Interwave, Mem>, @@ -94,18 +96,22 @@ using KernelTypesCompV3 = ::testing::Types< using KernelTypesCompV3Wmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, + std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, + std::tuple< Row, Col, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, + std::tuple< Col, Row, Row, I8, 
I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, + std::tuple< Col, Col, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3> >;