Shuffle fix for gfx950 (#3491)

* solve compiler issue

* solve the gfx950 mfma shuffle regression

* refactor jenkinsfile to handle arch name better

* [CK TILE] set divisor to count of thread along k dimension

* fix the compiler error

* solve degradation

* Finish the multiplies fix

* fix the scales

* solve compilation error

* solve the composes

* solve the error of tile sweeper

* fix the test and example

* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Thomas Ning
2026-01-14 01:21:29 +08:00
committed by GitHub
parent 9908a87c31
commit 00c46785a8
33 changed files with 161 additions and 152 deletions

View File

@@ -1193,39 +1193,40 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(saturates<ck_tile::fp8_t>{},
scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
{

View File

@@ -1538,10 +1538,11 @@ struct FmhaFwdKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(
ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
@@ -1553,9 +1554,10 @@ struct FmhaFwdKernel
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{
scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,

View File

@@ -1325,30 +1325,32 @@ struct FmhaFwdPagedKVKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}
else
{

View File

@@ -457,14 +457,15 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
}
else
{

View File

@@ -1069,10 +1069,11 @@ struct FmhaFwdSplitKVKernel
bias_dram_window,
identity{}, // bias_element_func
lse_acc_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
kargs.num_splits,
i_split_,
mask,