INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>
This commit is contained in:
root
2025-07-08 17:48:00 +08:00
parent 20b11c9424
commit 98fe3af7dc
5 changed files with 47 additions and 6 deletions

View File

@@ -91,6 +91,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16_gfx12 =
// Attribute implementation built on the WmmaTraits specialization tagged
// <gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>.
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16_gfx12 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>>;
// Attribute implementation built on the WmmaTraits specialization tagged
// <gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16> (int8 inputs, int32 result type).
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8_gfx12 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>>;
// Attribute implementation built on the WmmaTraits specialization tagged
// <gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>.
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8_gfx12 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;

View File

@@ -30,6 +30,31 @@ struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
}
};
// GFX12 WMMA traits: int8 A/B operands, int32 accumulator, 16x16x16 tile tag.
template <>
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
    : WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
{
    /// Issues one iu8 WMMA on the supplied operand vectors and returns the result.
    ///
    /// @tparam clamp  Forwarded unchanged as the last (clamp) argument of the builtin.
    /// @param a_vec   A operand; reinterpreted bit-for-bit as int32x2_t.
    /// @param b_vec   B operand; reinterpreted bit-for-bit as int32x2_t.
    /// @param c_vec   Accumulator input; reinterpreted bit-for-bit as int32x8_t.
    /// @return The builtin's result when compiled for __gfx12__; otherwise CVecType{0}.
    template <bool clamp = false>
    CK_TILE_DEVICE static CVecType
    wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
    {
#if defined(__gfx12__)
        // Both per-operand flags are hard-wired to true (the "neg_a"/"neg_b" slots).
        // NOTE(review): for the iu8 builtins these flags are understood to select a
        // signed interpretation of the operand - confirm against the LLVM docs.
        constexpr bool a_flag = true;
        constexpr bool b_flag = true;

        const auto a_packed = bit_cast<int32x2_t>(a_vec);
        const auto b_packed = bit_cast<int32x2_t>(b_vec);
        const auto acc_in   = bit_cast<int32x8_t>(c_vec);

        return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            a_flag, a_packed, b_flag, b_packed, acc_in, clamp);
#else
        // Not compiling for gfx12: consume the arguments so they do not trigger
        // unused-parameter diagnostics, then hand back a placeholder value.
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        ck_tile::ignore = c_vec;
        return CVecType{0};
#endif
    }
};
// fp8/bf8 specialization - GFX12
template <>
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>

View File

@@ -39,15 +39,12 @@ template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::ha
// fp16 x fp16 -> float, 16x16x16: WMMA (gfx12) warp gemm; TransposeC is forwarded to the
// warp gemm alias and the following argument is fixed to false.
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16_gfx12<TransposeC>;};
#else
// Fallback branch of the enclosing preprocessor conditional (its opening is above this
// excerpt): fp16 16x16x16 maps to the MFMA warp gemms instead.
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
// NOTE(review): the next two lines are textually identical; presumably the old/new pair
// of a whitespace-only change in this diff view - confirm only one exists in the file.
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
#endif
// fp16 MFMA variants with the argument after TransposeC set to true (SwizzleA, per the
// parameter-order comment further down): 32x32x8 and 32x32x16.
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
//TODO: currently int8 in this location; need to move
// int8 x int8 -> int32, 16x16x16: gfx11 WMMA warp gemm, selected without any architecture guard.
// NOTE(review): the same argument pattern is specialized again in the int8 section of
// this file; presumably this copy is the one being moved there - confirm.
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11<TransposeC>;};
// fp16 2:4 structural sparsity
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
@@ -111,9 +108,15 @@ template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8
// int8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
// MFMA int8 32x32x16: plain and C-transposed variants.
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
// NOTE(review): the next two lines are textually identical; presumably the old/new pair
// of a whitespace-only change in this diff view - confirm only one exists in the file.
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
// MFMA int8 16x16x32: plain and C-transposed variants.
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
// NOTE(review): same duplicated-line situation as above.
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
// WMMA int8 16x16x16: the warp gemm alias is chosen per target architecture.
#if defined(__gfx11__)
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11<TransposeC>;};
// NOTE(review): despite the trailing comment, this branch is taken whenever __gfx11__ is
// not defined (any other target, and presumably the host pass too), not only for gfx12.
#else // __gfx12__
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx12<TransposeC>;};
#endif
// clang-format on
} // namespace impl

View File

@@ -26,6 +26,10 @@ template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_bf16_bf16_gfx12 = WarpGemmImpl<
WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16_gfx12, kTransC>>;
// Warp gemm alias wrapping the gfx12 int8 (i32 16x16x16) WMMA attribute implementation.
// kTransC is passed straight through to WarpGemmAtrributeWmma and defaults to false.
template <bool kTransC = false>
using WarpGemmWmma_i32_16x16x16_i8_i8_gfx12 = WarpGemmImpl<
WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8_gfx12, kTransC>>;
// Warp gemm alias wrapping the gfx12 fp8 (f32 16x16x16) WMMA attribute implementation.
// kTransC is passed straight through to WarpGemmAtrributeWmma and defaults to false.
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_f8_gfx12 = WarpGemmImpl<
WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8_gfx12, kTransC>>;

View File

@@ -58,6 +58,8 @@ using KernelTypesMemWmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
@@ -94,18 +96,22 @@ using KernelTypesCompV3 = ::testing::Types<
// Typed-test parameter list for the CompV3 pipeline on WMMA targets.
// Each std::tuple carries 13 entries: three layout tags, four data-type tags, four
// ck_tile::number<> constants (64, 32, 16, 16), then Intrawave and CompV3.
// NOTE(review): presumably the layouts are A/B/C, the data types are A/B/Acc/C and the
// numbers are tile extents - confirm against the test fixture that unpacks the tuple.
// The list is the cross product of four leading layout pairs with five data-type rows
// (F16, BF16, I8, F8, BF8); the I8 rows use I32 for both of the last two data-type slots.
using KernelTypesCompV3Wmma = ::testing::Types<
// Row / Row
std::tuple< Row, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
// Row / Col
std::tuple< Row, Col, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Col, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
// Col / Row
std::tuple< Col, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
// Col / Col
std::tuple< Col, Col, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Col, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>
>;