update hotloop pipeline

This commit is contained in:
mtgu0705
2025-09-08 04:01:40 -05:00
parent 146963d62a
commit 0509597f55
2 changed files with 294 additions and 250 deletions


@@ -504,8 +504,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_dram_window,
const ScaleBDramBlockWindowTmp& scale_b_flat_window,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
@@ -622,8 +622,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
// auto scale_b_flat_distribution =
// PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
@@ -631,12 +629,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// auto scale_b_flat_dram_window = make_tile_window(
// scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
// make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
// scale_b_flat_window.get_window_origin(),
// scale_b_flat_distribution);
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
// use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
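For context on the ub.mxfp4 writes below: treating the loaded fp4 tile as packed dwords moves raw bits between buffers with no per-nibble unpack/repack. A minimal stand-alone sketch of the type-punning idea, with stand-in types (the real code uses thread_buffer and the decltype aliases above; the union-style view itself is an assumption inferred from ub.mxfp4):

#include <cstdint>
#include <cstring>

// Stand-ins for MXFP4_Buffer / V4UInt_Buffer: 32 fp4 values = 16 bytes = 4 dwords.
struct MXFP4Tile { uint8_t nibbles[16]; }; // two fp4 values per byte
struct V4UInt    { uint32_t dw[4];      };

// Reinterpret the raw fp4 bits as dwords without decoding any nibble;
// memcpy is the portable equivalent of the bit_cast used in the pipeline.
inline V4UInt as_v4uint(const MXFP4Tile& t)
{
    static_assert(sizeof(V4UInt) == sizeof(MXFP4Tile), "views must match");
    V4UInt v;
    std::memcpy(&v, &t, sizeof(v));
    return v;
}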
@@ -651,37 +643,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
// statically_indexed_array<
// statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
// ScaleNPerWarp>
// scale_b_flat_dram_windows;
// statically_indexed_array<
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_ping;
// statically_indexed_array<
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_pong;
// pingpong buffer for Scale A and Scale B
auto scale_a_dram_window = make_tile_window(
scale_a_dram_window.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock / MXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
scale_a_dram_window.get_window_origin(),
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
scale_a_window.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ScaleA_DramTileDistribution<Problem>());
auto scale_b_dram_window = make_tile_window(
scale_b_dram_window.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock / NXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
scale_b_dram_window.get_window_origin(),
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
scale_b_window.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
// ping pong buffer for scale A
statically_indexed_array<
statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_dram_windows;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_tile_tensor_ping;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_tile_tensor_pong;
// ping pong buffer for scale B
statically_indexed_array<
statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_dram_windows;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_ping;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_pong;
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
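The paired *_ping/*_pong declarations above implement register double buffering: while the warp GEMM consumes one set, the next K block's B tiles and scales are prefetched into the other. A generic sketch of the pattern (stand-in container, not the CK statically_indexed_array machinery):

// Generic double buffer: fill(iter) names the buffer being prefetched while
// consume(iter) names the one feeding the current GEMM; the roles swap every
// K block, so global loads overlap the math instead of serializing with it.
template <typename Buf>
struct PingPong
{
    Buf bufs[2];
    Buf&       fill(int iter)          { return bufs[(iter + 1) & 1]; }
    const Buf& consume(int iter) const { return bufs[iter & 1]; }
};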
@@ -690,27 +699,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -718,11 +713,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
// prefetch Scale A and Scale B
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
});
});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// A_Lds_TileDist may differ from ADramTileDistribution
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
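The scale windows advance by kKPerBlock / (32 * KXdlPack) because one E8M0 scale covers an MX block of 32 fp4 elements, and KXdlPack scales are packed per XDL column. A worked instance with assumed tile numbers (not values taken from this diff):

#include <cstdio>

int main()
{
    constexpr int kKPerBlock = 256; // fp4 elements per K block (assumed)
    constexpr int MXBlock    = 32;  // elements sharing one E8M0 scale
    constexpr int KXdlPack   = 2;   // scales packed per XDL (assumed)
    // 256 elements -> 8 scales -> 4 packed columns per K block.
    constexpr int step = kKPerBlock / (MXBlock * KXdlPack);
    std::printf("scale window K step = %d\n", step); // prints 4
}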
@@ -751,110 +769,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
__builtin_amdgcn_sched_barrier(0);
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
constexpr int float_mantissa = 23;
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
#if defined(__gfx950__)
auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
pk_mxfp4x4, fscale, int(byte_idx));
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
pk_mxfp4x4, fscale, int(byte_idx));
}
else
{
static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
}
};
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4x4_to_compute_v2(
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
});
#else
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
}
};
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4_to_compute_v2(
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
.at(i),
bit_cast<float>(uscale)));
});
#endif
};
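The uscale trick in the lambda above decodes an E8M0 scale by placing its 8-bit exponent directly in a float's exponent field: shifting by the 23-bit mantissa width and bit-casting yields exactly 2^(e-127), which the conversion intrinsics then consume as a plain float. A stand-alone sketch (portable memcpy in place of bit_cast; the e = 0 and e = 255 encodings are ignored here):

#include <cstdint>
#include <cstring>

inline float e8m0_to_float(uint8_t e)
{
    // Sign 0, exponent e, mantissa 0  ==>  2^(e - 127) for e in [1, 254].
    uint32_t bits = uint32_t(e) << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
// e8m0_to_float(127) == 1.0f, e8m0_to_float(128) == 2.0f, e8m0_to_float(126) == 0.5f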
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -862,6 +790,29 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
});
});
// Prefill A(2i+1)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
@@ -872,55 +823,74 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale B
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
}
});
});
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -933,32 +903,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// Next K
// prefetch B(2i+2)
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -966,6 +919,29 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// prefetch Scale A and Scale B (2i+2)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
});
});
// Prefill A(2i+2)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
@@ -976,54 +952,74 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale B
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
}
});
});
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
@@ -1032,8 +1028,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
iCounter--;
}
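Stripped of the tile machinery, the loop above follows a two-K-blocks-per-iteration schedule. A runnable skeleton of that control flow with stubbed stages (the stage names mirror the diff; everything else is a stand-in, not the CK API):

#include <cstdio>

void prefetch(int k, const char* buf) { std::printf("prefetch block %d -> %s\n", k, buf); }
void gemm(int k, const char* buf)     { std::printf("gemm     block %d <- %s\n", k, buf); }

int main()
{
    const int num_loop = 8;          // total K blocks (assumed)
    prefetch(0, "ping");             // prologue, mirrors the code before the loop
    int k = 0;
    for(int iCounter = (num_loop - 1) / 2; iCounter > 0; --iCounter)
    {
        prefetch(k + 1, "pong");     // B + scale A/B for block 2i+1
        gemm(k, "ping");             // GEMM 2i
        prefetch(k + 2, "ping");     // B + scale A/B for block 2i+2
        gemm(k + 1, "pong");         // GEMM 2i+1
        k += 2;
    }
    std::printf("tail handles remaining block(s) from %d\n", k);
}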
// TAIL


@@ -155,9 +155,9 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NWavePerBlk, NXdlPack>, // second
// direction
sequence<WaveRepeat>, // ?
tuple<sequence<NWavePerBlk, NXdlPack>, // second direction
// TODO: double-check NXdlPack placement here
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
@@ -273,6 +273,56 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
constexpr index_t MWavePerBlk = M_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1>, sequence<2, 1>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1>, sequence<2, 1>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
};
} // namespace ck_tile
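For concreteness, the lane split in the two scale distributions above satisfies N_Lane * K_Lane == 64, so each 64-lane wave covers one warp tile's worth of rows times the matching number of packed-scale columns. A worked instance with an assumed 16x16 warp tile (the tile size is an assumption, not a value from this diff):

#include <cstdio>

int main()
{
    constexpr int WaveSize  = 64;
    constexpr int WarpTileN = 16;                   // TileShape::WarpTile::at(I1), assumed
    constexpr int N_Lane    = WarpTileN;            // lanes along N
    constexpr int K_Lane    = WaveSize / WarpTileN; // lanes along packed K = 4
    static_assert(N_Lane * K_Lane == WaveSize, "lanes must exactly tile the wave");
    std::printf("N_Lane=%d, K_Lane=%d\n", N_Lane, K_Lane);
}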