mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
fix f4 pipeline issues
This commit is contained in:
@@ -417,7 +417,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
// Local prefetch 1, sync the async load
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
static_for<0, MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, math::min(2 * MXdlPack, MRepeat), 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
@@ -465,9 +465,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(scale_mem_buf));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
|
||||
//block_sync_lds();
|
||||
//a_blockwise_copy.Run(
|
||||
// a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
@@ -512,10 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
//a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages))
|
||||
{
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto lds_buf =
|
||||
m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages)
|
||||
? scale_mem_buf
|
||||
: scale_comp_buf;
|
||||
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
constexpr index_t a_scale_offset =
|
||||
@@ -602,15 +614,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
});
|
||||
});
|
||||
});
|
||||
if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
constexpr auto lds_buf =
|
||||
m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages)
|
||||
? scale_mem_buf
|
||||
: scale_comp_buf;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
@@ -642,6 +645,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
// HotLoopScheduler();
|
||||
@@ -1090,13 +1094,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<2>{}, I1, Number<MXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(Number<KPack * MXdlPack>{},
|
||||
Number<KRepeat * MRepeat * KPack>{},
|
||||
Number<MRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
I1));
|
||||
// Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack]
|
||||
// Order: 1 0 3 2 4
|
||||
static constexpr auto ARegBuf = 2;
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<ARegBuf>{},
|
||||
I1,
|
||||
Number<MXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KRepeat * MXdlPack* KPack>{},
|
||||
Number<ARegBuf * MXdlPack * KRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
Number<MXdlPack*KPack>{},
|
||||
I1));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
|
||||
Reference in New Issue
Block a user