[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~
Author: Yi DING
Date:   2025-11-20 10:35:15 +08:00 (committed by GitHub)
Parent: 4e49e0228b
Commit: 47e2ed838e

17 changed files with 698 additions and 595 deletions


@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
auto&& desc = transform_tensor_descriptor(
naive_desc,
make_tuple(make_pass_through_transform(kFlatN),
make_merge_transform_v3_division_mod(
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
const auto& ds_tensor_view = generate_tuple(
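
The rewrite above does not change B's flat memory layout: the merged (kFlatKBlocks x flatKPerBlock) axis linearizes to the same packed row-major offset as the old strided (kFlatN, kFlatK) view, but the inner extent is now the compile-time `number<flatKPerBlock>{}`, which is what the new static_assert checks against GetVectorSizeB(). A standalone sketch of that offset equivalence, with made-up example sizes (not CK code):

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int N = 256, K = 512;  // example problem sizes (assumed)
    constexpr int kKPerBlock = 256;  // FlatmmPipeline::kKPerBlock (assumed)
    constexpr int kNWarpTile = 16;   // BlockGemmShape::WarpTile::at(I1) (assumed)

    // Old view: 2D with runtime extents and strides.
    constexpr int kFlatK = K * kNWarpTile;
    constexpr int kFlatN = N * K / kFlatK; // == N / kNWarpTile

    // New view: packed 3D, inner extent known at compile time.
    constexpr int flatKPerBlock = kKPerBlock * kNWarpTile;
    constexpr int kFlatKBlocks  = K / kKPerBlock;
    static_assert(kFlatKBlocks * flatKPerBlock == kFlatK, "same flat-K extent");

    for(int n = 0; n < kFlatN; ++n)
        for(int k = 0; k < kFlatK; ++k)
        {
            const int old_offset = n * kFlatK + k; // naive strided view
            // merge_transform(kFlatKBlocks, flatKPerBlock): split k, re-linearize
            const int new_offset =
                (n * kFlatKBlocks + k / flatKPerBlock) * flatKPerBlock
                + k % flatKPerBlock;
            assert(old_offset == new_offset);
        }
    std::printf("flat-B offsets agree for %d x %d\n", kFlatN, kFlatK);
}
```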


@@ -44,7 +44,10 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
else if(TailNumber::Odd == tail_num)
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
else
{
assert(("Wrong TailNumber!", false));
return decltype(TailHandler<>(run_func, true, TailNumber::Even)){};
}
}
};
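
The new error branch uses two small idioms worth spelling out: the comma operator lets `assert` carry a message (the string literal is evaluated and discarded; the `false` is what fires), and returning a value-initialized result keeps every control path of the dispatcher well-formed even though the branch is unreachable. A minimal standalone sketch:

```cpp
#include <cassert>

enum class TailNumber { Odd, Even };

int dispatch(TailNumber tail)
{
    if(tail == TailNumber::Odd)
        return 1;
    else if(tail == TailNumber::Even)
        return 2;
    else
    {
        assert(("Wrong TailNumber!", false)); // message appears in assert output
        return int{}; // value-initialized fallback; never actually reached
    }
}
```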


@@ -43,7 +43,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int KXdlPack = 2;
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
};
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
@@ -122,9 +122,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
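
AK1/BK1 above count logical K elements per vector load; PackedSize folds in the two-values-per-byte fp4 packing, so the same hardware load covers twice as many K elements for pk_fp4_t as for fp8. A standalone sketch, assuming a 16-byte (dwordx4) VectorLoadSize (the real value comes from the Problem type):

```cpp
#include <cstdio>

constexpr int elements_per_load(int vector_load_bytes, int sizeof_type, int packed_size)
{
    return vector_load_bytes / sizeof_type * packed_size;
}

int main()
{
    constexpr int VectorLoadSize = 16; // bytes, assumed
    static_assert(elements_per_load(VectorLoadSize, 1, 2) == 32, "pk_fp4_t: 32 fp4/load");
    static_assert(elements_per_load(VectorLoadSize, 1, 1) == 16, "fp8: 16 elems/load");
    std::printf("fp4: %d, fp8: %d elements per 16B load\n",
                elements_per_load(VectorLoadSize, 1, 2),
                elements_per_load(VectorLoadSize, 1, 1));
}
```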
@@ -138,25 +139,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static constexpr index_t mfma_per_wg = 1; // 950 only
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
Problem::VectorLoadSize ==
0);
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp;
static constexpr index_t ScaleBload_num =
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize;
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize;
static constexpr index_t ScaleAload_num =
kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize;
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
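
The hunk above rewrites the LDS-read bookkeeping from byte arithmetic to element arithmetic (via AK1) and parameterizes the scale-load counts by ScaleGranularityK instead of the hard-coded 32, so the same formulas serve fp4 and fp8. A standalone sketch of the per-K-slice counts, with assumed example numbers (a 16x16x128 warp tile, 64-lane waves, AK1 = 32, MIterPerWarp = 4, NWarp = 4):

```cpp
#include <cstdio>

int main()
{
    constexpr int WaveSize = 64, AK1 = 32;
    constexpr int kM = 16, kK = 128;           // WG::kM, WG::kK (assumed)
    constexpr int MIterPerWarp = 4, NWarp = 4; // assumed block shape

    // LDS reads one wave issues to feed one warp tile of A per K slice.
    constexpr int dsread_per_wg = kM * kK / AK1 / WaveSize;
    static_assert(kM * kK % (AK1 * WaveSize) == 0, "reads must tile exactly");

    constexpr int dsread_num_perK  = dsread_per_wg * MIterPerWarp;
    constexpr int dswrite_num_perK = dsread_num_perK / NWarp; // writes shared across N warps

    std::printf("dsread/wg=%d perK=%d dswrite perK=%d\n",
                dsread_per_wg, dsread_num_perK, dswrite_num_perK);
}
```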
@@ -219,7 +220,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
{
@@ -234,7 +235,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
{
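
A short annotated sketch (HIP/Clang device code) of the scheduling hint being disabled in the two hunks above. `__builtin_amdgcn_sched_group_barrier(mask, size, sync_id)` asks the backend to place `size` instructions matching `mask` at this point; per the LLVM AMDGPU documentation, 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write. With the switch to async (direct-to-LDS) A loads there are no separate DS-write instructions left to schedule, which is presumably why the 0x200 slots are commented out:

```cpp
__device__ void interleave_hint()
{
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read
    // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write: none with async copy
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
}
```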
@@ -470,18 +471,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
}
template <typename ADramBlockWindowTmp,
typename AElementFunction,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_ping,
void* __restrict__ p_smem_pong) const
{
#ifndef __gfx950__
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
@@ -495,9 +494,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iNWarp = get_warp_id() % NWarp;
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
@@ -506,6 +504,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto a_dram_window =
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
a_copy_dram_window_tmp.get_bottom_tensor_view()),
a_copy_dram_window_tmp.get_window_lengths(),
a_copy_dram_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
@@ -520,93 +525,51 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_store_lds_window_ping = make_tile_window(
a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window_pong = make_tile_window(
a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
auto a_warp_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
auto a_warp_window_pong_tmp =
auto a_warp_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_ping;
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
auto packed_m_idx = mIter / number<MXdlPack>{};
auto packed_m_rank = mIter % number<MXdlPack>{};
move_tile_window(
a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
move_tile_window(
a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_flatmm = BlockFlatmm();
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf
{
V4UInt_B_Buffer u = 0;
MXFP4_B_Buffer mxfp4;
} ub;
// pingpong buffer for B
auto b_flat_dram_windows = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
auto window_i = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
move_tile_window(
window_i,
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
number<0>{}});
return window_i;
},
number<NIterPerWarp>{});
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
b_warp_tensor_ping, b_warp_tensor_pong;
// pingpong buffer for Scale A and Scale B
auto scale_a_dram_window = make_tile_window(
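
The long hunk above also deletes the `UnionBuf` type punning: the old code loaded an MXFP4 tile and punned it through a union to a `thread_buffer<uint32_t, 4>` ("use v4i32 as the data type between basicblock to avoid unpack and repack operation"), whereas `load_tile_with_offset` now yields that register buffer directly. A standalone sketch of what the removed idiom did, spelled with C++20 `bit_cast` since union punning is only conditionally supported in portable C++ (the tile type here is a stand-in):

```cpp
#include <array>
#include <bit>
#include <cstdint>
#include <cstdio>

struct MXFP4Tile { std::array<std::uint8_t, 16> bytes; }; // 32 fp4 values, stand-in type

int main()
{
    MXFP4Tile t{};
    t.bytes[0] = 0x12;
    // One 128-bit register quad: the "v4i32" of the diff's comment.
    auto v4 = std::bit_cast<std::array<std::uint32_t, 4>>(t);
    std::printf("lane0 = 0x%08x\n", static_cast<unsigned>(v4[0]));
}
```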
@@ -649,29 +612,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
};
// HEAD
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
async_load_tile_(a_store_lds_window_ping, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
});
// move B window to next flat K
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
// prefetch Scale A
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -700,71 +658,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
// move Scale B window to next K
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// A_Lds_TileDist may differ with ADramTileDistribution
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
{
async_load_tile_(a_store_lds_window_pong, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
}
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
clear_tile(c_block_tile);
block_sync_lds();
using MXFP4_A_Buffer_ping =
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf_A_ping
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_ping mxfp4;
} ua_ping;
using MXFP4_A_Buffer_pong =
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
union UnionBuf_A_pong
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_pong mxfp4;
} ua_pong;
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
// preload A00,A10... from lds
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
s_waitcnt_barrier</*vmcnt*/ dswrite_num_perK>();
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
auto main_body_implx2 = [&]() mutable {
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
});
});
@@ -791,15 +718,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(2i+1)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -807,30 +725,26 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -838,68 +752,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// Prefetch A(2i+2)
async_load_tile_(a_store_lds_window_ping, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// preload A(2i+1)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
HotLoopScheduler();
// Next K
////////////////////////////// Next K //////////////////////////////
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
});
});
@@ -926,15 +832,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(2i+2)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -953,20 +850,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -988,39 +878,47 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_pong,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// Prefetch A(2i+3)
async_load_tile_(a_store_lds_window_pong, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// preload A(2i+2)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
HotLoopScheduler();
};
iCounter--;
if constexpr(HasHotLoop)
{
index_t iCounter = (num_loop - 1) / 2;
do
{
main_body_implx2();
iCounter--;
} while(iCounter > 0);
}
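
The control-flow refactor above hoists the unrolled-by-2 ping/pong body into the `main_body_implx2` lambda and only runs it under `if constexpr(HasHotLoop)`, with the odd/even tail handled afterwards. A standalone sketch of that skeleton (buffer contents are placeholders for the LDS ping/pong tiles):

```cpp
#include <cstdio>

template <bool HasHotLoop>
void pipeline(int num_loop)
{
    int ping = 0, pong = 1; // stand-ins for the two LDS buffers

    auto main_body_implx2 = [&] {
        std::printf("compute on %d, prefetch into %d\n", ping, pong); // iteration 2i
        std::printf("compute on %d, prefetch into %d\n", pong, ping); // iteration 2i+1
    };

    if constexpr(HasHotLoop)
    {
        int iCounter = (num_loop - 1) / 2; // pairs of iterations before the tail
        do
        {
            main_body_implx2();
            iCounter--;
        } while(iCounter > 0);
    }
    // Odd tail: 1 remaining iteration; Even tail: 2.
    std::printf("tail: last %d iteration(s)\n", 1 + (num_loop % 2 == 0));
}

int main() { pipeline<true>(5); pipeline<false>(1); }
```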
// TAIL
@@ -1029,18 +927,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter),
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
});
});
@@ -1055,7 +944,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
@@ -1067,10 +955,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(loopK)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// GEMM loopK-1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -1089,20 +973,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1124,30 +1001,28 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// preload A(2i+1)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
Last2ndHotLoopScheduler();
@@ -1170,19 +1045,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1204,18 +1073,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_pong,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
@@ -1244,20 +1106,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1279,18 +1134,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
@@ -1299,32 +1147,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
LastHotLoopScheduler();
}
else
{
static_assert(false, "Wrong TailNum");
}
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp,
scale_a_flat_window_tmp,
scale_b_flat_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
}
};
} // namespace ck_tile


@@ -13,22 +13,139 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t KBPerLoad = 32;
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using namespace ck_tile;
static inline constexpr auto wg_attr_num_access =
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
? WGAttrNumAccessEnum::Single
: WGAttrNumAccessEnum::Double;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access<Problem>>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
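
A standalone sketch of the XOR swizzle this descriptor applies over (M1 = 4, K1 = 8), assuming `make_xor_transform` maps (m1, k1) to (m1, k1 ^ (m1 % K1)) — the exact modulo convention is an assumption here. The XOR rotates each row group's packet order so that the 128-byte packets written by the same lane group in consecutive rows land in different LDS bank groups instead of stacking on one:

```cpp
#include <cstdio>

int main()
{
    constexpr int M1 = 4, K1 = 8; // row group / 128B packets per row (from the diff)
    for(int m1 = 0; m1 < M1; ++m1)
    {
        for(int k1 = 0; k1 < K1; ++k1)
            std::printf("%d ", k1 ^ (m1 % K1)); // swizzled packet index
        std::printf("  <- row %d\n", m1);
    }
}
```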
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = get_warp_size() / K1; // 8
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
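
A standalone sketch of the factorization this distribution encodes, with assumed numbers: 256 threads, 64-lane waves, a 128 x 256 (M x K) block tile, and the fp4 case K2 = 32 (fp8 would give K2 = 16). Each lane owns one K2-element packet, K1 lanes tile a 128-byte DRAM packet, and the remaining wave lanes and waves tile M:

```cpp
#include <cstdio>

int main()
{
    constexpr int BlockSize = 256, WarpSize = 64;
    constexpr int MPerBlock = 128, KPerBlock = 256; // assumed tile
    constexpr int APackedSize = 2;                  // pk_fp4_t
    constexpr int K2 = 16 * APackedSize;            // GetSmemPackA * PackedSize = 32 (assumed)
    constexpr int K1 = 128 * APackedSize / K2;      // 128B DRAM packet -> 8 lanes
    constexpr int K0 = KPerBlock / (K1 * K2);       // 1
    constexpr int M2 = WarpSize / K1;               // 8 rows per wave
    constexpr int M1 = BlockSize / WarpSize;        // 4 waves
    constexpr int M0 = MPerBlock / (M2 * M1);       // 4 repeats
    static_assert(M0 * M1 * M2 == MPerBlock && K0 * K1 * K2 == KPerBlock,
                  "factorization must cover the whole block tile");
    std::printf("M: %dx%dx%d  K: %dx%dx%d\n", M0, M1, M2, K0, K1, K2);
}
```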
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
@@ -36,65 +153,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 32
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<M0>{},
number<M1>{},
number<K0>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(M1),
make_pass_through_transform(K0),
make_pass_through_transform(M2),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
a_lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
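
A standalone sketch of the padded LDS stride arithmetic above, using the fp4 numbers in the descriptor's comments (K2 = 32, K1 = 8, M3 = 4, M2 = 2) and an assumed K0 = 2. The `Pad = 4 * K2` elements inserted after each K0 slice stagger the slice base addresses, so DS reads of the same (m, k) coordinate in different slices hit different banks:

```cpp
#include <cstdio>

int main()
{
    constexpr int K2 = 32, K1 = 8, M3 = 4, M2 = 2; // from the descriptor comments
    constexpr int K0 = 2;                          // e.g. KPerBlock = 512 -> 2 slices (assumed)
    constexpr int Pad = 4 * K2;                    // 128 fp4 elements
    constexpr int slice = M2 * M3 * K1 * K2;       // elements per (m2,m3,k1,k2) slice
    for(int k0 = 0; k0 < K0; ++k0)
        std::printf("k0=%d base offset = %d elements\n", k0, k0 * (slice + Pad));
}
```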
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
{
@@ -105,20 +227,31 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0);
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
constexpr int K_Lane = 64 / M_Lane; // 4
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr int K1 = K_Thread / num_access_v; // 16
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
std::conditional_t<
num_access_v == 1,
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>,
tile_distribution_encoding< //
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>>{});
}
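
A standalone sketch of the selection pattern used above: a constexpr access count (Single for packed fp4, Double for fp8 in this pipeline) picks one of two encoding types at compile time via `std::conditional_t`, splitting the per-thread K extent as K_Thread = num_access * K1. The stand-in structs only mirror that shape split, not the real encodings:

```cpp
#include <cstdio>
#include <type_traits>

template <int K1> struct SingleAccess { static constexpr int accesses = 1, k1 = K1; };
template <int N, int K1> struct MultiAccess { static constexpr int accesses = N, k1 = K1; };

template <int KThread, int NumAccess>
using Encoding = std::conditional_t<NumAccess == 1,
                                    SingleAccess<KThread>,
                                    MultiAccess<NumAccess, KThread / NumAccess>>;

int main()
{
    using F4 = Encoding<32, 1>; // pk_fp4_t: one 32-element access
    using F8 = Encoding<32, 2>; // fp8: two 16-element accesses
    std::printf("fp4: %d x %d, fp8: %d x %d\n",
                F4::accesses, F4::k1, F8::accesses, F8::k1);
}
```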
template <typename Problem>
@@ -132,25 +265,36 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t kKPerThread = 32;
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t K2 = kKPerThread / num_access_v;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>,
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
std::conditional_t< //
num_access_v == 1,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
template <typename Problem>
@@ -270,6 +414,21 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) *
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA<Problem>();
}
};
} // namespace ck_tile
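
A standalone sketch of the new GetSmemSizeA() arithmetic at the end of the policy: LDS bytes = element_space_size * sizeof(ADataType) / PackedSize, since one pk_fp4_t byte (sizeof == 1) carries two logical fp4 elements. The element count below is a made-up example in place of the descriptor query:

```cpp
#include <cstdio>

constexpr int smem_bytes(int element_space, int sizeof_type, int packed_size)
{
    return element_space * sizeof_type / packed_size;
}

int main()
{
    constexpr int elems = 128 * (512 + 128); // e.g. padded 128 x 512 fp4 A tile (assumed)
    static_assert(smem_bytes(elems, 1, 2) == elems / 2, "two fp4 per byte");
    std::printf("A tile LDS: %d bytes for %d fp4 elements\n",
                smem_bytes(elems, 1, 2), elems);
}
```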