Shuffle fix for gfx950 (#3491)

* solve compiler issue

* solve the gfx950 mfma shuffle regression

* refactor jenkinsfile to handle arch name better

* [CK TILE] set divisor to count of thread along k dimension

* fix the compiler error

* solve degradation

* Finish the multiplies fix

* fix the scales

* solve compilation error

* solve the composes

* solve the error of tile sweeper

* fix the test and example

* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Thomas Ning
2026-01-14 01:21:29 +08:00
committed by GitHub
parent 9908a87c31
commit 00c46785a8
33 changed files with 161 additions and 152 deletions

View File

@@ -24,7 +24,7 @@ struct ElementWiseShape
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -19,7 +19,8 @@ struct TileFlatmmShape
static constexpr auto idxN = number<1>{};
static constexpr auto idxK = number<2>{};
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t kM = BlockTile::at(idxM);
static constexpr index_t kN = BlockTile::at(idxN);

View File

@@ -1193,39 +1193,40 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(saturates<ck_tile::fp8_t>{},
scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
{

View File

@@ -1538,10 +1538,11 @@ struct FmhaFwdKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(
ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
@@ -1553,9 +1554,10 @@ struct FmhaFwdKernel
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{
scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,

View File

@@ -1325,30 +1325,32 @@ struct FmhaFwdPagedKVKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}
else
{

View File

@@ -457,14 +457,15 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
}
else
{

View File

@@ -1069,10 +1069,11 @@ struct FmhaFwdSplitKVKernel
bias_dram_window,
identity{}, // bias_element_func
lse_acc_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
kargs.num_splits,
i_split_,
mask,

View File

@@ -42,9 +42,9 @@ struct TileFmhaShape
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
@@ -95,10 +95,10 @@ struct TileFmhaBwdShape
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{}) &&
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies<>{}, number<1>{}));
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen

View File

@@ -56,10 +56,10 @@ struct FusedMoeGemmShape
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
reduce_on_sequence(WarpPerBlock_0{}, multiplies<>{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies<>{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});

View File

@@ -19,7 +19,8 @@ struct TileGemmShape
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});

View File

@@ -52,6 +52,6 @@ struct PoolShape
static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -345,7 +345,7 @@ struct BlockReduce2D
constexpr auto row_y_unpacks = [&]() {
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
constexpr auto row_y_size =
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
reduce_on_sequence(row_y_lengths, multiplies<>{}, number<1>{});
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
static_assert(row_y_size % row_y_packs == 0);

View File

@@ -39,6 +39,6 @@ struct Reduce2dShape
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -96,7 +96,7 @@ struct TopkSoftmaxWarpPerRowPipeline
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
}
};
tile_sweeper ts{w_, w_f};
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
ts();
return w_;
#endif