Support for MFMA_16x16x128 for fp8/bf8 (#2125)

* Adding 16x16x128 support for gfx950

* Support for fp8 and bf8

* fix input arguments for MFMA scale instruction

* clang-formatted

* Fixes for lwpck-3145 (#2138)

* Fix LDS tile, CMake dependency, and default epilogue

* Fallback BTypeToUse to ADataType in WOQ cases

* reverting instance json file

* reverting instance json file

---------

Co-authored-by: Yi DING <yi.ding@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-04-28 18:19:50 -07:00
committed by GitHub
parent 768c99eca9
commit d107f3c3a5
8 changed files with 143 additions and 10 deletions

View File

@@ -25,7 +25,9 @@ struct Default2DEpilogueProblem
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename AccDataType_,
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
@@ -38,6 +40,8 @@ template <typename AccDataType_,
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
@@ -96,17 +100,22 @@ struct Default2DEpilogue
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// Used for weight-only quantization kernels: B is dequantized to the same data type as A.
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,