mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4592 (commit 45f76cb)
Tile Engine support for gfx950 ## Motivation This PR adds support for the gfx950 GPU architecture to the Tile Engine in Composable Kernel library, focusing on GEMM operations with FP8 and BF8 data types. ## Technical Details Added gfx950-specific MFMA warp GEMM implementations with conditional compilation. Updated default GEMM configuration parameters for tile sizes and warp configurations. Added Jenkins CI pipeline stage for testing TILE_ENGINE_GEMM on gfx950 hardware. ## Test Plan Tile engine itself is a benchmarking utility, so if it passes the CI it will be tested automatically. ## Test Result Tile engine itself is a benchmarking utility, so if it passes the CI it will be tested automatically. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
17e4c8eac9
commit
9059730caf
@@ -293,6 +293,44 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterat
|
||||
4>>;
|
||||
|
||||
// fp8
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
@@ -309,14 +347,6 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
@@ -335,10 +365,6 @@ using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
@@ -105,10 +105,7 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, true> { u
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
@@ -118,7 +115,6 @@ template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Ty
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
@@ -146,6 +142,16 @@ template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, fal
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EQuad>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<EDouble>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<EDouble>; };
|
||||
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
|
||||
Reference in New Issue
Block a user