mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Shuffle fix for gfx950 (#3491)
* solve compiler issue
* solve the gfx950 mfma shuffle regression
* refactor jenkinsfile to handle arch name better
* [CK TILE] set divisor to count of thread along k dimension
* fix the compiler error
* solve degradation
* Finish the multiplies fix
* fix the scales
* solve compilation error
* solve the composes
* solve the error of tile sweeper
* fix the test and example
* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
@@ -24,7 +24,7 @@ struct ElementWiseShape
|
||||
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -19,7 +19,8 @@ struct TileFlatmmShape
|
||||
static constexpr auto idxN = number<1>{};
|
||||
static constexpr auto idxK = number<2>{};
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(idxM);
|
||||
static constexpr index_t kN = BlockTile::at(idxN);
|
||||
|
||||
@@ -1193,39 +1193,40 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
return make_composes(saturates<ck_tile::fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
|
||||
else
|
||||
return ck_tile::scales{scale_o};
|
||||
return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1538,10 +1538,11 @@ struct FmhaFwdKernel
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
return make_composes(
|
||||
ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
|
||||
else
|
||||
return ck_tile::scales{scale_o};
|
||||
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
@@ -1553,9 +1554,10 @@ struct FmhaFwdKernel
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
|
||||
@@ -1325,30 +1325,32 @@ struct FmhaFwdPagedKVKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
|
||||
kargs.scale_p}, // p_compute_element_func
|
||||
make_composes(saturates<fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
|
||||
kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -457,14 +457,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
smem_ptr);
|
||||
return FmhaPipeline{}(lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
make_composes(saturates<fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
|
||||
kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1069,10 +1069,11 @@ struct FmhaFwdSplitKVKernel
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_acc_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
|
||||
kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
|
||||
@@ -42,9 +42,9 @@ struct TileFmhaShape
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
@@ -95,10 +95,10 @@ struct TileFmhaBwdShape
|
||||
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
|
||||
@@ -56,10 +56,10 @@ struct FusedMoeGemmShape
|
||||
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(WarpPerBlock_0{}, multiplies<>{}, number<1>{});
|
||||
|
||||
// TODO: we don't support rounding half warps up to 1 warp here
|
||||
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
|
||||
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
|
||||
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
|
||||
|
||||
@@ -19,7 +19,8 @@ struct TileGemmShape
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
using WarpTile = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
|
||||
@@ -52,6 +52,6 @@ struct PoolShape
|
||||
static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -345,7 +345,7 @@ struct BlockReduce2D
|
||||
constexpr auto row_y_unpacks = [&]() {
|
||||
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
|
||||
constexpr auto row_y_size =
|
||||
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(row_y_lengths, multiplies<>{}, number<1>{});
|
||||
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
|
||||
|
||||
static_assert(row_y_size % row_y_packs == 0);
|
||||
|
||||
@@ -39,6 +39,6 @@ struct Reduce2dShape
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -96,7 +96,7 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
|
||||
}
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user