Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-05 06:01:23 +00:00
fix after merge ginolu/add_wgmfma_dispatcher
@@ -21,7 +21,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
     using TilePartitioner = remove_cvref_t<TilePartitioner_>;
     using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
     using BlockGemmShape =
-        remove_cvref_t<typename MXFlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
+        remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
     using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
     using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
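The fix here replaces the undeclared name `MXFlatmmPipeline` with the template parameter `MXFlatmmPipeline_`, the same name the other aliases already decay. A minimal sketch of the alias pattern, assuming ck_tile's `remove_cvref_t` matches C++20 `std::remove_cvref_t` (the `MyPipeline` type below is a hypothetical stand-in):

    #include <type_traits>

    // Strip references and cv-qualifiers, as the kernel's aliases do.
    template <typename T>
    using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

    struct MyPipeline { using BlockGemmShape = int; }; // hypothetical stand-in

    template <typename MXFlatmmPipeline_>
    struct KernelSketch
    {
        // The underscore-suffixed template parameter is the only name in
        // scope here; every derived alias must be spelled in terms of it.
        using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
        using BlockGemmShape = typename remove_cvref_t<MXFlatmmPipeline_>::BlockGemmShape;
    };

    static_assert(std::is_same_v<KernelSketch<const MyPipeline&>::BlockGemmShape, int>);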
@@ -36,17 +36,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
     // Below type is actually accumulation data type - the output of block GEMM.
     using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

-    using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
-    using MThreadPerXdl = BlockGemm::WarpGemm::kM;
-    using NThreadPerXdl = BlockGemm::WarpGemm::kN;
-    using KThreadPerXdl = get_warp_size() / MThreadPerXdl;
+    using BlockGemm = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemm>;
+    static constexpr int MThreadPerXdl = BlockGemm::WarpGemm::kM;
+    static constexpr int NThreadPerXdl = BlockGemm::WarpGemm::kN;
+    static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;

     static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
     static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;

-    static constexpr int MXdlPack = remove_cvref_t<typename FlatmmPipeline::MXdlPack>;
-    static constexpr int NXdlPack = remove_cvref_t<typename FlatmmPipeline::NXdlPack>;
-    static constexpr int KXdlPack = remove_cvref_t<typename FlatmmPipeline::KXdlPack>;
+    static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
+    static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
+    static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;

     static constexpr index_t NumDTensor = DsDataType::size();

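The core fix in this hunk: `using` introduces a type alias, so the old `using MThreadPerXdl = BlockGemm::WarpGemm::kM;` is ill-formed when `kM` is a value; `static constexpr int` is the correct spelling. The same reasoning removes `remove_cvref_t` (a type trait) from the `MXdlPack`/`NXdlPack`/`KXdlPack` values. The literal `64` replacing `get_warp_size()` matches the 64-lane wavefront on AMD CDNA GPUs. A minimal sketch (the `WarpGemm` extents are hypothetical):

    struct WarpGemm
    {
        static constexpr int kM = 16; // hypothetical warp-tile extents
        static constexpr int kN = 16;
    };

    // using MThreadPerXdl = WarpGemm::kM;             // ill-formed: kM is a value, not a type
    static constexpr int MThreadPerXdl = WarpGemm::kM; // OK: a constant, not an alias
    static constexpr int KThreadPerXdl = 64 / MThreadPerXdl; // 64 = CDNA wavefront size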
@@ -55,6 +55,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
     static constexpr auto I2 = number<2>();
     static constexpr auto I3 = number<3>();
     static constexpr auto I4 = number<4>();
+    static constexpr auto I5 = number<5>();

     static_assert(DsLayout::size() == DsDataType::size(),
                   "The size of DsLayout and DsDataType should be the same");
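The added `I5` extends the existing set of compile-time index constants. A minimal model, assuming `number<N>` behaves like `std::integral_constant` (this sketch is not the ck_tile definition):

    #include <cstdint>

    // Model of number<N>: a value that carries N in its type.
    template <std::int32_t N>
    struct number
    {
        static constexpr std::int32_t value = N;
        constexpr operator std::int32_t() const { return N; }
    };

    static constexpr auto I4 = number<4>();
    static constexpr auto I5 = number<5>(); // the constant this commit adds
    static_assert(I5 == 5);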
@@ -76,7 +77,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
         hipDeviceProp_t prop;
         int deviceId = 0; // default device

-        constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
+        constexpr int block_size = FlatmmPipeline::BlockSize().x;
         int dync_smem_size       = 0;
         int maxActiveBlocksPerCU = 0;

@@ -86,7 +87,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
             &maxActiveBlocksPerCU,
             reinterpret_cast<void*>(
                 kentry2<block_size,
-                        F16xMXF4FlatmmKernel,
+                        FlatmmPipeline,
                         FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
             block_size,
             dync_smem_size);
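For context, `maxActiveBlocksPerCU` is filled by HIP's occupancy API from the kernel entry point, block size, and dynamic LDS size, as shown in the hunk above. A self-contained sketch of that query (`dummy_kernel` is a hypothetical stand-in for the real `kentry2` instantiation):

    #include <hip/hip_runtime.h>

    __global__ void dummy_kernel() {} // stand-in for the kentry2 instantiation

    int query_blocks_per_cu(int block_size, int dync_smem_size)
    {
        int maxActiveBlocksPerCU = 0;
        // How many blocks of this size can be resident per compute unit.
        (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(dummy_kernel),
            block_size,
            dync_smem_size);
        return maxActiveBlocksPerCU;
    }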
@@ -201,7 +202,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
         auto scale_a = kargs.scale_m_ptr;
         auto scale_b = kargs.scale_n_ptr;

-        static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
+        static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;

         // A scale tensor view
         const auto& scale_a_tensor_view = [&]() {
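The literal 32 matches the OCP microscaling (MX) convention of one shared scale per 32 consecutive elements along K, which is what `GranularityK` encoded; hardcoding it sidesteps the `decltype(scale_n)` lookup. Shape arithmetic under that assumption (the problem sizes below are hypothetical):

    constexpr int BlockScaleSize = 32; // one scale per 32 K-elements (OCP MX)

    constexpr int M = 4096, K = 8192;          // hypothetical problem sizes
    constexpr int ScaleK = K / BlockScaleSize; // scales per row of A: 256
    static_assert(K % BlockScaleSize == 0, "K must be a multiple of the scale block");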
@@ -215,7 +216,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
                 scale_a_naive_desc,
                 make_tuple(
                     make_merge_transform(
-                        make_tuple(Padded_Scale_M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
+                        make_tuple(kargs.M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
                     make_merge_transform(make_tuple(
                         kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl), KThreadPerXdl))),
                 make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
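The changed line swaps `Padded_Scale_M` for `kargs.M` as the outer extent of the merged M dimension. The merge transform fuses `(M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)` into a single index; a sketch of that arithmetic, assuming row-major merge order (all sizes hypothetical):

    constexpr int MThreadPerXdl = 16, MXdlPack = 2, M = 4096;
    constexpr int outer_len = M / (MXdlPack * MThreadPerXdl); // 128
    constexpr int inner_len = MThreadPerXdl;                  // 16

    // merged index = outer * inner_len + inner (row-major fusion)
    constexpr int merged(int outer, int inner) { return outer * inner_len + inner; }
    static_assert(merged(outer_len - 1, inner_len - 1) == outer_len * inner_len - 1);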
@@ -397,8 +398,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
         //                32>{}),
         //     {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});

-        auto scale_a                        = kargs.scale_m_ptr;
-        static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
+        // auto scale_a = kargs.scale_m_ptr;
+        static constexpr int BlockScaleSize = 32;

         auto scale_a_block_window = make_tile_window(
             views.at(I4),
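This hunk comments out the now-unused local `scale_a` and hardcodes `BlockScaleSize = 32` in the second code path as well, before `make_tile_window` binds the scale-A view to a per-block window. A generic model of the idea (this struct is an illustration of a windowed view, not the ck_tile API):

    // Illustration only: a tile window = a view plus an origin and an extent.
    struct TileWindow2D
    {
        const float* base;      // start of the underlying tensor view
        int stride;             // row stride of the view
        int origin_m, origin_k; // top-left corner of this block's window

        float at(int m, int k) const { return base[(origin_m + m) * stride + (origin_k + k)]; }
    };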