fix f4 pipeline issues

This commit is contained in:
joye
2025-05-23 17:13:10 +08:00
parent 97709c4aa1
commit 8afac88f89

View File

@@ -417,7 +417,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, MXdlPack, 1>{}([&](auto m0) {
static_for<0, math::min(2 * MXdlPack, MRepeat), 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
@@ -465,9 +465,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
b_block_origin_idx,
b_thread_bufs(scale_mem_buf));
block_sync_lds();
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
//block_sync_lds();
//a_blockwise_copy.Run(
// a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -512,10 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
//a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages))
{
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
constexpr auto lds_buf =
m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages)
? scale_mem_buf
: scale_comp_buf;
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
@@ -602,15 +614,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages))
{
block_sync_lds();
}
constexpr auto lds_buf =
m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages)
? scale_mem_buf
: scale_comp_buf;
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
@@ -642,6 +645,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
});
// HotLoopScheduler();
@@ -1090,13 +1094,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
}
}
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<2>{}, I1, Number<MXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(Number<KPack * MXdlPack>{},
Number<KRepeat * MRepeat * KPack>{},
Number<MRepeat * KPack>{},
Number<KPack>{},
I1));
// Per-thread descriptor for the A operand, built from an explicit
// (lengths, strides) pair rather than a packed row-major layout.
//
// Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack]
// Order: 1 0 3 2 4
//   "Order" lists the dimensions from largest stride to smallest:
//   MWave (dim 1) > ARegBuf (dim 0) > KRepeat (dim 3) > MXdlPack (dim 2) > KPack (dim 4).
//   So KPack is contiguous, MXdlPack steps over one KPack group, and KRepeat
//   steps over a whole MXdlPack * KPack tile.
//
// NOTE(review): ARegBuf is presumably the number of A register buffers that
// the pipeline alternates between — confirm against the hot loop.
static constexpr auto ARegBuf = 2;
static constexpr auto a_thread_desc_ =
// Lengths, in dimension order 0..4.
make_naive_tensor_descriptor(make_tuple(Number<ARegBuf>{},
I1,
Number<MXdlPack>{},
Number<KRepeat>{},
Number<KPack>{}),
// Strides, in the same dimension order.
// dim 0 (ARegBuf): one full KRepeat x MXdlPack x KPack slab per buffer.
make_tuple(Number<KRepeat * MXdlPack* KPack>{},
// dim 1 (length 1): stride equals the total element count of all buffers.
Number<ARegBuf * MXdlPack * KRepeat * KPack>{},
// dim 2 (MXdlPack): consecutive entries are KPack elements apart.
Number<KPack>{},
// dim 3 (KRepeat): consecutive entries are MXdlPack * KPack elements apart.
Number<MXdlPack*KPack>{},
// dim 4 (KPack): innermost, unit stride.
I1));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,