mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Support for MFMA_16x16x128 for fp8/bf8 (#2125)
* Adding 16x16x128 support for gfx950 * Support for fp8 and bf8 * fix input arguments for MFMA scale instruction * clang-formatted * Fixes for lwpck-3145 (#2138) * Fix lds tile & cmake dep & default epilogue * Fallback BTypeToUse to ADataType in WOQ cases * reverting instance json file * reverting instance json file --------- Co-authored-by: Yi DING <yi.ding@amd.com>
This commit is contained in:
@@ -25,7 +25,9 @@ struct Default2DEpilogueProblem
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
};
|
||||
|
||||
template <typename AccDataType_,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
bool kPadM_,
|
||||
@@ -38,6 +40,8 @@ template <typename AccDataType_,
|
||||
struct DefaultGemm2DEpilogueProblem
|
||||
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
@@ -96,17 +100,22 @@ struct Default2DEpilogue
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ODataType,
|
||||
ODataType,
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
|
||||
Reference in New Issue
Block a user