mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA
This commit is contained in:
@@ -38,8 +38,8 @@ K0_MAX_SUBMAX_MAP = {
|
||||
}
|
||||
|
||||
SEQLENQ_MAP = {
|
||||
# "16" : "16",
|
||||
# "32" : "32",
|
||||
"16" : "16",
|
||||
"32" : "32",
|
||||
# "64" : "64"
|
||||
"128" : "128",
|
||||
# "256" : "256",
|
||||
@@ -133,18 +133,18 @@ using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
|
||||
//if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
// && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
// || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
// if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
// instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
// }} else {{
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
// }}
|
||||
//}} else {{
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
//}}
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
// instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} // anonymous namespace
|
||||
|
||||
@@ -153,20 +153,20 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
|
||||
template<>
|
||||
void fmha_fwd_decode_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_decode_args a)
|
||||
{{
|
||||
//if constexpr({F_mode} == false) {{ // batch mode
|
||||
// // we don't check every seqlen_k values for kvcache
|
||||
// if (a.seqlen_k_ptr != nullptr) {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// // make sure F_bn0 is divisible by F_bk1
|
||||
// }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
// run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
// }} else {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// }}
|
||||
//}} else {{
|
||||
// run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
//}}
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
// we don't check every seqlen_k values for kvcache
|
||||
if (a.seqlen_k_ptr != nullptr) {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// make sure F_bn0 is divisible by F_bk1
|
||||
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
// run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}}
|
||||
|
||||
template<>
|
||||
@@ -659,15 +659,15 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '64': {
|
||||
'64': {
|
||||
# Specialize for different SeqQ
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# },
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
},
|
||||
'128': {
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256': FmhaFwdTileSize(256, 64, 32, 128, 16, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
},
|
||||
|
||||
@@ -95,7 +95,15 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
template <index_t lgkmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(lgkmcnt));
|
||||
if constexpr(lgkmcnt > 15)
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(15));
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(lgkmcnt));
|
||||
}
|
||||
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
|
||||
{
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
|
||||
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
|
||||
|
||||
thread_buffer<T, 2> v;
|
||||
v(0) = bit_cast<T>(x[0]);
|
||||
v(1) = bit_cast<T>(x[1]);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
|
||||
@@ -227,14 +227,20 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<QDataType*>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_store_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<QDataType*>(smem_ptr),
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem, true>());
|
||||
|
||||
auto q_lds_store_window =
|
||||
make_tile_window(q_lds_write_view,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tile_window(q_lds_read_view,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
@@ -452,7 +458,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
// Set CrossWarp to false will trigger better strategy on gfx950, but will cause
|
||||
// performance regression because of un-coexecutable packed math, silent it for now
|
||||
block_tile_reduce_sync(
|
||||
m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
@@ -505,7 +514,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
block_tile_reduce_sync(
|
||||
rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
|
||||
|
||||
auto p_tile = make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
@@ -964,7 +974,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
block_tile_reduce_sync(
|
||||
m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
|
||||
|
||||
static_for<0, 12, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
@@ -1029,7 +1040,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
block_tile_reduce_sync(
|
||||
rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
|
||||
|
||||
auto p_tile = make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
|
||||
@@ -14,10 +14,14 @@ namespace ck_tile {
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
|
||||
template <typename AccDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
bool WithBroadcast = true,
|
||||
bool CrossWarp = true>
|
||||
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
bool_constant<WithBroadcast> = {})
|
||||
bool_constant<WithBroadcast> = {},
|
||||
bool_constant<CrossWarp> = {})
|
||||
{
|
||||
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
@@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
if constexpr(CrossWarp)
|
||||
{
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
// pull data from remote lane
|
||||
const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
|
||||
// reduce
|
||||
v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user