mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add an XOR length-fold strategy for hdim < 128. Performance dropped with it enabled, so it is disabled by default pending further performance debugging.
This commit is contained in:
@@ -17,6 +17,9 @@
|
||||
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
|
||||
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
|
||||
|
||||
// Enabling this can remove all bank conflicts, but it reduces performance in some cases.
|
||||
// The slowdown is probably caused by limitations in compiler optimization.
|
||||
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
@@ -756,33 +759,102 @@ struct FmhaFwdDecodeKernel
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
const auto q_dram_unmerged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_pass_through_transform(seqlen_q),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
const auto q_dram_unmerged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
|
||||
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
const auto q_dram_permuted = transform_tensor_view(
|
||||
q_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
seqlen_q,
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
q_dram_unmerged,
|
||||
make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
q_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(seqlen_q),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto q_dram_unmerged_xor = transform_tensor_view(
|
||||
q_dram_merged,
|
||||
make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<LDSLayerSize / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto q_dram_permuted = transform_tensor_view(
|
||||
q_dram_unmerged_xor,
|
||||
make_tuple(
|
||||
make_xor_transform(
|
||||
make_tuple(seqlen_q / XorLengthFold,
|
||||
number<LDSLayerSize / FmhaPipeline::kAlignmentQ>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
const auto q_dram_tmp = transform_tensor_view(
|
||||
q_dram_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(seqlen_q / XorLengthFold),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<XorLengthFold>{},
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
q_dram_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
const auto q_dram_unmerged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(
|
||||
make_pass_through_transform(seqlen_q),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto q_dram_permuted = transform_tensor_view(
|
||||
q_dram_unmerged,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(
|
||||
seqlen_q,
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
q_dram_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(seqlen_q),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -806,32 +878,96 @@ struct FmhaFwdDecodeKernel
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto k_dram_unmerged = transform_tensor_view(
|
||||
k_dram_pad,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
const auto k_dram_permuted = transform_tensor_view(
|
||||
k_dram_unmerged,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(
|
||||
height, number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
const auto k_dram_unmerged = transform_tensor_view(
|
||||
k_dram_pad,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(height / XorLengthFold, XorLengthFold)),
|
||||
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
k_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto k_dram_merged = transform_tensor_view(
|
||||
k_dram_unmerged,
|
||||
make_tuple(make_pass_through_transform(height / XorLengthFold),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto k_dram_unmerged_xor = transform_tensor_view(
|
||||
k_dram_merged,
|
||||
make_tuple(make_pass_through_transform(height / XorLengthFold),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<LDSLayerSize / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto k_dram_permuted = transform_tensor_view(
|
||||
k_dram_unmerged_xor,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(height / XorLengthFold,
|
||||
number<LDSLayerSize / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
const auto k_dram_tmp = transform_tensor_view(
|
||||
k_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(height / XorLengthFold),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<XorLengthFold>{},
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
k_dram_tmp,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
const auto k_dram_unmerged = transform_tensor_view(
|
||||
k_dram_pad,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto k_dram_permuted = transform_tensor_view(
|
||||
k_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
height,
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
k_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
};
|
||||
const auto k_dram = [&]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
@@ -852,41 +988,102 @@ struct FmhaFwdDecodeKernel
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO: Add kVHeadDim
|
||||
constexpr index_t XorGroupSize =
|
||||
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
const auto v_dram_pad = pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenK, false>{});
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
// TODO: Add kVHeadDim
|
||||
// TrLoad is performed in 16x4/16x8/16x16 units; the fast dimension is 16 elements
|
||||
constexpr auto XorGroupSize =
|
||||
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
const auto v_dram_unmerged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(length / XorLengthFold, XorLengthFold)),
|
||||
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
const auto v_dram_unmerged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
const auto v_dram_merged = transform_tensor_view(
|
||||
v_dram_unmerged,
|
||||
make_tuple(make_pass_through_transform(length / XorLengthFold),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto v_dram_permuted = transform_tensor_view(
|
||||
v_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
const auto v_dram_unmerged_xor = transform_tensor_view(
|
||||
v_dram_merged,
|
||||
make_tuple(make_pass_through_transform(length / XorLengthFold),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<LDSLayerSize / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
v_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto v_dram_permuted = transform_tensor_view(
|
||||
v_dram_unmerged_xor,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
length / XorLengthFold, number<LDSLayerSize / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
const auto v_dram_tmp = transform_tensor_view(
|
||||
v_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(length / XorLengthFold),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<XorLengthFold>{},
|
||||
number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
v_dram_tmp,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
const auto v_dram_unmerged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto v_dram_permuted = transform_tensor_view(
|
||||
v_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
v_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(length),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
};
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
@@ -917,8 +1114,8 @@ struct FmhaFwdDecodeKernel
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram, make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), {0, 0});
|
||||
|
||||
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
|
||||
/// following copy capture of the 'i_nhead' if in C++20
|
||||
/// FIXME: Before C++20, capturing structured binding variables is not supported.
|
||||
/// Remove following copy capture of the 'i_nhead' if in C++20
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto bias_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
@@ -1091,8 +1288,8 @@ struct FmhaFwdDecodeKernel
|
||||
const auto o_acc_dram_naive = [&] {
|
||||
if constexpr(kMergeNumHeadGroupsSeqLenQ)
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
|
||||
// hdim_v)
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk *
|
||||
// seqlen_q, hdim_v)
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
|
||||
// Enabling this can remove all bank conflicts, but it reduces performance in some cases.
|
||||
// The slowdown is probably caused by limitations in compiler optimization.
|
||||
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
|
||||
namespace ck_tile {
|
||||
// In this pipeline, Q, K and V are all located in LDS
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
@@ -220,35 +223,75 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto q_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
|
||||
|
||||
const auto q_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_unmerged,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kMPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -275,35 +318,75 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto k_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
|
||||
|
||||
const auto k_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
k_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_unmerged,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kNPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kNPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
k_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kNPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock>{},
|
||||
number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -332,35 +415,77 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto XorGroupSize =
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock;
|
||||
|
||||
const auto v_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / XorGroupSize>{},
|
||||
number<XorGroupSize>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<XorGroupSize>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kKPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kKPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(make_tuple(number<XorLengthFold>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<XorGroupSize>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
number<kKPerBlock>{}, number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user