Add XOR fold strategy for hdim<128; performance dropped, so it is disabled by default pending further performance debugging

This commit is contained in:
aska-0096
2025-08-05 07:23:51 +00:00
parent 0d12fc944f
commit 414cad667b
2 changed files with 479 additions and 157 deletions

View File

@@ -17,6 +17,9 @@
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
// Can remove all bank conflicts, but drops performance for some cases;
// this is probably limited by compiler optimization.
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
@@ -756,33 +759,102 @@ struct FmhaFwdDecodeKernel
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
const auto q_dram_unmerged = transform_tensor_view(
q_dram_pad,
make_tuple(make_pass_through_transform(seqlen_q),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
if constexpr(XorLengthFold > 1)
{
const auto q_dram_unmerged = transform_tensor_view(
q_dram_pad,
make_tuple(make_unmerge_transform(
make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto q_dram_permuted = transform_tensor_view(
q_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
seqlen_q,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto q_dram_merged = transform_tensor_view(
q_dram_unmerged,
make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
make_merge_transform_v3_division_mod(make_tuple(
XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_view(
q_dram_permuted,
make_tuple(make_pass_through_transform(seqlen_q),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto q_dram_unmerged_xor = transform_tensor_view(
q_dram_merged,
make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
make_unmerge_transform(make_tuple(
number<LDSLayerSize / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto q_dram_permuted = transform_tensor_view(
q_dram_unmerged_xor,
make_tuple(
make_xor_transform(
make_tuple(seqlen_q / XorLengthFold,
number<LDSLayerSize / FmhaPipeline::kAlignmentQ>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto q_dram_tmp = transform_tensor_view(
q_dram_permuted,
make_tuple(
make_pass_through_transform(seqlen_q / XorLengthFold),
make_unmerge_transform(make_tuple(
number<XorLengthFold>{},
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_view(
q_dram_tmp,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
const auto q_dram_unmerged = transform_tensor_view(
q_dram_pad,
make_tuple(
make_pass_through_transform(seqlen_q),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto q_dram_permuted = transform_tensor_view(
q_dram_unmerged,
make_tuple(
make_xor_transform(make_tuple(
seqlen_q,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
q_dram_permuted,
make_tuple(
make_pass_through_transform(seqlen_q),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
else
{
@@ -806,32 +878,96 @@ struct FmhaFwdDecodeKernel
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(make_pass_through_transform(height),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged,
make_tuple(
make_xor_transform(make_tuple(
height, number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
if constexpr(XorLengthFold > 1)
{
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(
make_unmerge_transform(make_tuple(height / XorLengthFold, XorLengthFold)),
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
k_dram_permuted,
make_tuple(make_pass_through_transform(height),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto k_dram_merged = transform_tensor_view(
k_dram_unmerged,
make_tuple(make_pass_through_transform(height / XorLengthFold),
make_merge_transform_v3_division_mod(
make_tuple(XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto k_dram_unmerged_xor = transform_tensor_view(
k_dram_merged,
make_tuple(make_pass_through_transform(height / XorLengthFold),
make_unmerge_transform(
make_tuple(number<LDSLayerSize / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged_xor,
make_tuple(make_xor_transform(
make_tuple(height / XorLengthFold,
number<LDSLayerSize / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto k_dram_tmp = transform_tensor_view(
k_dram_permuted,
make_tuple(make_pass_through_transform(height / XorLengthFold),
make_unmerge_transform(make_tuple(
number<XorLengthFold>{},
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_view(
k_dram_tmp,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(make_pass_through_transform(height),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
height,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
k_dram_permuted,
make_tuple(make_pass_through_transform(height),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
@@ -852,41 +988,102 @@ struct FmhaFwdDecodeKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// TODO: Add kVHeadDim
constexpr index_t XorGroupSize =
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
const auto v_dram_pad = pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenK, false>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
// TODO: Add kVHeadDim
// TrLoad is performed in 16x4/16x8/16x16 units; the fast dimension is 16 elements
constexpr auto XorGroupSize =
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
if constexpr(XorLengthFold > 1)
{
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,
make_tuple(
make_unmerge_transform(make_tuple(length / XorLengthFold, XorLengthFold)),
make_pass_through_transform(number<FmhaPipeline::kQKHeaddim>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,
make_tuple(make_pass_through_transform(length),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_dram_merged = transform_tensor_view(
v_dram_unmerged,
make_tuple(make_pass_through_transform(length / XorLengthFold),
make_merge_transform_v3_division_mod(
make_tuple(XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto v_dram_permuted = transform_tensor_view(
v_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto v_dram_unmerged_xor = transform_tensor_view(
v_dram_merged,
make_tuple(make_pass_through_transform(length / XorLengthFold),
make_unmerge_transform(make_tuple(
number<LDSLayerSize / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
return transform_tensor_view(
v_dram_permuted,
make_tuple(make_pass_through_transform(length),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto v_dram_permuted = transform_tensor_view(
v_dram_unmerged_xor,
make_tuple(make_xor_transform(make_tuple(
length / XorLengthFold, number<LDSLayerSize / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
const auto v_dram_tmp = transform_tensor_view(
v_dram_permuted,
make_tuple(make_pass_through_transform(length / XorLengthFold),
make_unmerge_transform(
make_tuple(number<XorLengthFold>{},
number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_view(
v_dram_tmp,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,
make_tuple(make_pass_through_transform(length),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_dram_permuted = transform_tensor_view(
v_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
length, number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
v_dram_permuted,
make_tuple(make_pass_through_transform(length),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / XorGroupSize>{},
number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
};
const auto v_dram = [&]() {
@@ -917,8 +1114,8 @@ struct FmhaFwdDecodeKernel
auto v_dram_window = make_tile_window(
v_dram, make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), {0, 0});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20
/// FIXME: Before C++20, capturing structured binding variables are not supported.
/// Remove following copy capture of the 'i_nhead' if in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
@@ -1091,8 +1288,8 @@ struct FmhaFwdDecodeKernel
const auto o_acc_dram_naive = [&] {
if constexpr(kMergeNumHeadGroupsSeqLenQ)
{
// reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
// hdim_v)
// reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk *
// seqlen_q, hdim_v)
const auto view = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),

View File

@@ -13,6 +13,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
// Can remove all bank conflicts, but drops performance for some cases;
// this is probably limited by compiler optimization.
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
namespace ck_tile {
// This pipeline is qkv all located in LDS
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
@@ -220,35 +223,75 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr auto q_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType);
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
const auto q_lds_block_desc_unmerged = transform_tensor_descriptor(
q_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
if constexpr(XorLengthFold > 1)
{
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / XorLengthFold>{},
number<LDSLayerSize / kKPack>{},
number<kKPack>{}),
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_unmerged,
make_tuple(make_xor_transform(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_naive,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / XorLengthFold>{},
number<LDSLayerSize / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<kMPerBlock / XorLengthFold>{}),
make_unmerge_transform(
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
q_lds_block_desc_tmp,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(
number<kMPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(
number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_naive,
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock>{},
number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
else
{
@@ -275,35 +318,75 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr auto k_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType);
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
const auto k_lds_block_desc_unmerged = transform_tensor_descriptor(
k_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
if constexpr(XorLengthFold > 1)
{
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / XorLengthFold>{},
number<LDSLayerSize / kKPack>{},
number<kKPack>{}),
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto k_lds_block_desc_permuted = transform_tensor_descriptor(
k_lds_block_desc_unmerged,
make_tuple(make_xor_transform(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
k_lds_block_desc_naive,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / XorLengthFold>{},
number<LDSLayerSize / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
k_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor(
k_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<kNPerBlock / XorLengthFold>{}),
make_unmerge_transform(
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
k_lds_block_desc_tmp,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kNPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(
number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
k_lds_block_desc_naive,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock>{},
number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
k_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
else
{
@@ -332,35 +415,77 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr auto XorGroupSize =
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType);
constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock;
const auto v_lds_block_desc_unmerged = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_unmerge_transform(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
if constexpr(XorLengthFold > 1)
{
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / XorLengthFold>{},
number<LDSLayerSize / XorGroupSize>{},
number<XorGroupSize>{}),
make_tuple(number<LDSLayerSize>{}, number<XorGroupSize>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_unmerged,
make_tuple(make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(
make_xor_transform(make_tuple(number<kKPerBlock / XorLengthFold>{},
number<LDSLayerSize / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<kKPerBlock / XorLengthFold>{}),
make_unmerge_transform(make_tuple(number<XorLengthFold>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
v_lds_block_desc_tmp,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(
number<kKPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
{
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}),
make_tuple(number<kNPerBlock>{}, number<XorGroupSize>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_xor_transform(make_tuple(
number<kKPerBlock>{}, number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<kKPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
else
{