Load Q through LDS; implement XOR-based swizzling

This commit is contained in:
aska-0096
2025-08-04 06:49:01 +00:00
parent 2d4e73d2b4
commit 746f4ccb99
4 changed files with 99 additions and 24 deletions

View File

@@ -42,6 +42,7 @@ SEQLENQ_MAP = {
# "32" : "32",
# "64" : "64"
"128" : "128",
# "256" : "256",
}
FMHA_FWD_DECODE_PIPELINE_MAP = {
@@ -668,6 +669,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128': FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256': FmhaFwdTileSize(256, 64, 32, 128, 16, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1),
},
}
else:

View File

@@ -590,7 +590,7 @@ struct FmhaFwdDecodeKernel
// 1. we expect KV data to be reused by different ThreadGroups, so use the cache
// 2. use more LDS, as we want better memory latency hiding
// If SplitKV is off, we don't expect Q data to be reused by different ThreadGroups, so bypass the cache
constexpr bool PrefillCase = FmhaPipeline::kM0 == 128;
constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
// divide problem
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
@@ -751,10 +751,38 @@ struct FmhaFwdDecodeKernel
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(
const auto seqlen_q = kargs.seqlen_q;
const auto q_dram_pad = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
const auto q_dram_unmerged = transform_tensor_view(
q_dram_pad,
make_tuple(make_pass_through_transform(seqlen_q),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto q_dram_permuted = transform_tensor_view(
q_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
seqlen_q,
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
q_dram_permuted,
make_tuple(make_pass_through_transform(seqlen_q),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
number<FmhaPipeline::kAlignmentQ>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{

View File

@@ -104,7 +104,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256)
return 1;
else
return 2;
@@ -728,23 +728,30 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
// Q tile in LDS
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem, true>());
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
// auto q_lds = make_tensor_view<address_space_enum::lds>(
// static_cast<QDataType*>(smem_ptrk0),
// Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
static_cast<QDataType*>(smem_ptrk0),
Policy::template MakeQLdsBlockDescriptor<Problem>());
// auto q_lds_store_window = make_tile_window(
// q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
static_cast<QDataType*>(smem_ptrk0),
Policy::template MakeQLdsBlockDescriptor<Problem, true>());
// auto q_lds_read_window =
// make_tile_window(q_lds,
// Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
// {0, 0},
// Policy::template MakeQRegTileDistribution<Problem>());
auto q_lds_store_window =
make_tile_window(q_lds_write_view,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// async_load_tile(q_lds_store_window, q_dram_window);
auto q_tile = load_tile(q_dram_window);
auto q_lds_read_window =
make_tile_window(q_lds_read_view,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>());
async_load_tile(q_lds_store_window, q_dram_window);
block_sync_lds_direct_load<0>();
auto q_tile = load_tile(q_lds_read_window);
// K tile in LDS
const index_t physical_seqlen_k_start = logical_seqlen_k_start;
@@ -825,6 +832,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
block_sync_lds<0>();
async_load_tile(k_lds_write_window, k_dram_window);
async_load_tile(v_lds_write_window, v_dram_window);

View File

@@ -54,7 +54,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
@@ -209,7 +209,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
template <typename Problem, bool Xor = false>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
@@ -217,11 +217,48 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr auto q_lds_block_desc =
make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto q_lds_block_desc_unmerged = transform_tensor_descriptor(
q_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto q_lds_block_desc_permuted = transform_tensor_descriptor(
q_lds_block_desc_unmerged,
make_tuple(make_xor_transform(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
q_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
}
}();
return q_lds_block_desc;
}