fix after merge ginolu/add_wgmfma_dispatcher

This commit is contained in:
mtgu0705
2025-09-09 04:37:42 -05:00
parent f119c30317
commit b0d71b8d19
9 changed files with 1037 additions and 339 deletions

View File

@@ -21,7 +21,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename MXFlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
@@ -36,17 +36,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
using MThreadPerXdl = BlockGemm::WarpGemm::kM;
using NThreadPerXdl = BlockGemm::WarpGemm::kN;
using KThreadPerXdl = get_warp_size() / MThreadPerXdl;
using BlockGemm = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemm>;
static constexpr int MThreadPerXdl = BlockGemm::WarpGemm::kM;
static constexpr int NThreadPerXdl = BlockGemm::WarpGemm::kN;
static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int MXdlPack = remove_cvref_t<typename FlatmmPipeline::MXdlPack>;
static constexpr int NXdlPack = remove_cvref_t<typename FlatmmPipeline::NXdlPack>;
static constexpr int KXdlPack = remove_cvref_t<typename FlatmmPipeline::KXdlPack>;
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -55,6 +55,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static constexpr auto I5 = number<5>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
@@ -76,7 +77,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
constexpr int block_size = FlatmmPipeline::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
@@ -86,7 +87,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
F16xMXF4FlatmmKernel,
FlatmmPipeline,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);
@@ -201,7 +202,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
auto scale_a = kargs.scale_m_ptr;
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
// A scale tensor view
const auto& scale_a_tensor_view = [&]() {
@@ -215,7 +216,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
scale_a_naive_desc,
make_tuple(
make_merge_transform(
make_tuple(Padded_Scale_M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
make_tuple(kargs.M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
make_merge_transform(make_tuple(
kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl), KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
@@ -397,8 +398,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
// 32>{}),
// {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
auto scale_a = kargs.scale_m_ptr;
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
// auto scale_a = kargs.scale_m_ptr;
static constexpr int BlockScaleSize = 32;
auto scale_a_block_window = make_tile_window(
views.at(I4),