mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
fix after merge ginolu/add_wgmfma_dispatcher
This commit is contained in:
@@ -21,7 +21,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename MXFlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
@@ -36,17 +36,17 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
|
||||
using MThreadPerXdl = BlockGemm::WarpGemm::kM;
|
||||
using NThreadPerXdl = BlockGemm::WarpGemm::kN;
|
||||
using KThreadPerXdl = get_warp_size() / MThreadPerXdl;
|
||||
using BlockGemm = remove_cvref_t<typename MXFlatmmPipeline_::BlockGemm>;
|
||||
static constexpr int MThreadPerXdl = BlockGemm::WarpGemm::kM;
|
||||
static constexpr int NThreadPerXdl = BlockGemm::WarpGemm::kN;
|
||||
static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
|
||||
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = remove_cvref_t<typename FlatmmPipeline::MXdlPack>;
|
||||
static constexpr int NXdlPack = remove_cvref_t<typename FlatmmPipeline::NXdlPack>;
|
||||
static constexpr int KXdlPack = remove_cvref_t<typename FlatmmPipeline::KXdlPack>;
|
||||
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -55,6 +55,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
static constexpr auto I5 = number<5>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
@@ -76,7 +77,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
|
||||
constexpr int block_size = FlatmmPipeline::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
@@ -86,7 +87,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
F16xMXF4FlatmmKernel,
|
||||
FlatmmPipeline,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
@@ -201,7 +202,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
|
||||
// A scale tensor view
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
@@ -215,7 +216,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
scale_a_naive_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
|
||||
make_tuple(kargs.M / (MXdlPack * MThreadPerXdl), MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl), KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
@@ -397,8 +398,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
// 32>{}),
|
||||
// {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
static constexpr int BlockScaleSize = decltype(scale_n)::GranularityK;
|
||||
// auto scale_a = kargs.scale_m_ptr;
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
|
||||
@@ -158,7 +158,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
// static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
|
||||
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
// static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -178,7 +178,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread;
|
||||
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
@@ -631,7 +631,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionB
|
||||
{
|
||||
V4UInt_Buffer u = 0;
|
||||
@@ -718,24 +718,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -793,23 +793,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -825,7 +825,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -850,7 +850,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -888,7 +888,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
@@ -905,8 +905,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
@@ -922,23 +922,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -979,7 +979,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1017,7 +1017,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
@@ -1036,15 +1036,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -1055,23 +1054,23 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1082,7 +1081,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -1107,7 +1106,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1181,7 +1180,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1224,7 +1223,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
@@ -1249,7 +1248,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1297,8 +1296,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramblockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramblockWindowTmp& scale_b_flat_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
|
||||
@@ -13,7 +13,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
// static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
// static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
@@ -35,10 +35,10 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
@@ -117,9 +117,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr index_t MXdlPack = Problm::MXdlPack;
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
@@ -209,24 +208,24 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kMPerBlock = tileShape::BlockTile::at(I0);
|
||||
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t K_Lanes = 64 / M_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = M_Lanes;
|
||||
static constexpr index_t Y1 = M_Warps;
|
||||
static constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
constexpr index_t Y2 = M_Lanes;
|
||||
constexpr index_t Y1 = M_Warps;
|
||||
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
|
||||
@@ -246,24 +245,24 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kNPerBlock = tileShape::BlockTile::at(I1);
|
||||
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Lanes = 64 / N_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = N_Lanes;
|
||||
static constexpr index_t Y1 = N_Warps;
|
||||
static constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
constexpr index_t Y2 = N_Lanes;
|
||||
constexpr index_t Y1 = N_Warps;
|
||||
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<M_Warps>, // ?
|
||||
@@ -284,19 +283,19 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr index_t MWavePerBlk = M_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -314,14 +313,14 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user