mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
@@ -91,6 +91,9 @@ using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16_gfx12 =
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16_gfx12 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8_gfx12 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8_gfx12 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;
|
||||
|
||||
|
||||
@@ -30,6 +30,31 @@ struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
// int8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
|
||||
bit_cast<int32x2_t>(a_vec),
|
||||
true, // neg_b
|
||||
bit_cast<int32x2_t>(b_vec),
|
||||
bit_cast<int32x8_t>(c_vec),
|
||||
clamp);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// fp8/bf8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
|
||||
|
||||
@@ -39,15 +39,12 @@ template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::ha
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16_gfx12<TransposeC>;};
|
||||
#else
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
//TODO: currently int8 in this location; need to move
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11<TransposeC>;};
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
@@ -111,9 +108,15 @@ template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
|
||||
#if defined(__gfx11__)
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx11<TransposeC>;};
|
||||
#else // __gfx12__
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8_gfx12<TransposeC>;};
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
@@ -26,6 +26,10 @@ template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf16_bf16_gfx12 = WarpGemmImpl<
    WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16_gfx12, kTransC>>;

/// Warp GEMM built on the gfx12 int8 WMMA attribute impl.
/// kTransC (default false) is forwarded as the second argument of WarpGemmAtrributeWmma.
template <bool kTransC = false>
using WarpGemmWmma_i32_16x16x16_i8_i8_gfx12 = WarpGemmImpl<
    WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8_gfx12, kTransC>>;

/// Warp GEMM built on the gfx12 fp8 WMMA attribute impl.
/// kTransC (default false) is forwarded as the second argument of WarpGemmAtrributeWmma.
template <bool kTransC = false>
using WarpGemmWmma_f32_16x16x16_f8_f8_gfx12 = WarpGemmImpl<
    WarpGemmAtrributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8_gfx12, kTransC>>;
|
||||
|
||||
@@ -58,6 +58,8 @@ using KernelTypesMemWmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, I8, I8, I32, I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Interwave, Mem>,
|
||||
@@ -94,18 +96,22 @@ using KernelTypesCompV3 = ::testing::Types<
|
||||
// Type list for the CompV3 WMMA test cases. Every entry uses the Intrawave
// scheduler tag, the CompV3 pipeline tag and the same four ck_tile::number
// values (64, 32, 16, 16); entries differ only in the three layout tags and
// the data-type tags.
// NOTE(review): the three leading layout tags are presumably the A, B and C
// layouts -- confirm against the test fixture that consumes these tuples.
using KernelTypesCompV3Wmma = ::testing::Types<
    // layouts: Row, Row, Row
    std::tuple<    Row,    Row,    Row,    F16,    F16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,   BF16,   BF16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,     I8,     I8,    I32,    I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,     F8,     F8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,    BF8,    BF8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    // layouts: Row, Col, Row
    std::tuple<    Row,    Col,    Row,    F16,    F16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,   BF16,   BF16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,     I8,     I8,    I32,    I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,     F8,     F8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,    BF8,    BF8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    // layouts: Col, Row, Row
    std::tuple<    Col,    Row,    Row,    F16,    F16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,   BF16,   BF16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,     I8,     I8,    I32,    I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,     F8,     F8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,    BF8,    BF8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    // layouts: Col, Col, Row
    std::tuple<    Col,    Col,    Row,    F16,    F16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,   BF16,   BF16,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,     I8,     I8,    I32,    I32, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,     F8,     F8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,    BF8,    BF8,    F32,    F16, ck_tile::number<64>, ck_tile::number<32>, ck_tile::number<16>, ck_tile::number<16>, Intrawave, CompV3>
    >;
|
||||
|
||||
Reference in New Issue
Block a user