Adding validation for tile sizes in Tile Engine (#2189)

* Adding validation for tile sizes

* Add architecture in config, and shuffle lines of code in warp_gemm.hpp

* Enable MFMA for gfx950, and invalid tile handling
This commit is contained in:
Khushbu Agarwal
2025-05-15 10:28:31 -07:00
committed by GitHub
parent 7c0e29cc0f
commit 3d8d6e75e4
4 changed files with 96 additions and 38 deletions

View File

@@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterat
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
@@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>>>;

View File

@@ -1092,7 +1092,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
@@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
@@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
@@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));