mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
update hotloop pipeline
This commit is contained in:
@@ -504,8 +504,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_dram_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
@@ -622,8 +622,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
|
||||
// auto scale_b_flat_distribution =
|
||||
// PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
@@ -631,12 +629,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// auto scale_b_flat_dram_window = make_tile_window(
|
||||
// scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
// make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
|
||||
// scale_b_flat_window.get_window_origin(),
|
||||
// scale_b_flat_distribution);
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
|
||||
@@ -651,37 +643,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
|
||||
// ScaleNPerWarp>
|
||||
// scale_b_flat_dram_windows;
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
|
||||
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_ping;
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
|
||||
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_dram_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock / MXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
|
||||
scale_a_draw_window.get_window_origin(),
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
auto scale_b_dram_winodow = make_tile_window(
|
||||
scale_b_dram_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock / NXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
|
||||
scale_b_dram_window.get_window_origin(),
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
|
||||
scale_b_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
// ping pong buffer for scale A
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_pong;
|
||||
|
||||
// ping pong buffer for scale B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_pong;
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -690,27 +699,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp});
|
||||
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -718,11 +713,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
|
||||
// prefetch Scale A and Scale B
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// A_Lds_TileDist may differ with ADramTileDistribution
|
||||
|
||||
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
@@ -751,110 +769,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
|
||||
|
||||
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
constexpr int float_mantissa = 23;
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4x4_to_compute_v2(
|
||||
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
|
||||
});
|
||||
#else
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
#endif
|
||||
};
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -862,6 +790,29 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
@@ -872,55 +823,74 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -933,32 +903,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -966,6 +919,29 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
@@ -976,54 +952,74 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShap::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -1032,8 +1028,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// TAIL
|
||||
|
||||
@@ -155,9 +155,9 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // second
|
||||
// direction
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // second >>>>>>>>>need to double confirm
|
||||
// direction
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
@@ -273,6 +273,56 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr index_t MWavePerBlk = M_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distributed_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user