mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Load Q through LDS, implement XOR
This commit is contained in:
@@ -42,6 +42,7 @@ SEQLENQ_MAP = {
|
||||
# "32" : "32",
|
||||
# "64" : "64"
|
||||
"128" : "128",
|
||||
# "256" : "256",
|
||||
}
|
||||
|
||||
FMHA_FWD_DECODE_PIPELINE_MAP = {
|
||||
@@ -668,6 +669,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128': FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256': FmhaFwdTileSize(256, 64, 32, 128, 16, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
},
|
||||
}
|
||||
else:
|
||||
|
||||
@@ -590,7 +590,7 @@ struct FmhaFwdDecodeKernel
|
||||
// 1. we expect KV data reused by different ThreadGroups, use cache
|
||||
// 2. use more LDS, as we want better memory latency hiding
|
||||
// If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the cache
|
||||
constexpr bool PrefillCase = FmhaPipeline::kM0 == 128;
|
||||
constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
@@ -751,10 +751,38 @@ struct FmhaFwdDecodeKernel
|
||||
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
const auto seqlen_q = kargs.seqlen_q;
|
||||
const auto q_dram_pad = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto q_dram_unmerged = transform_tensor_view(
|
||||
q_dram_pad,
|
||||
make_tuple(make_pass_through_transform(seqlen_q),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto q_dram_permuted = transform_tensor_view(
|
||||
q_dram_unmerged,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
seqlen_q,
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{})),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentQ>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_view(
|
||||
q_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(seqlen_q),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentQ>{},
|
||||
number<FmhaPipeline::kAlignmentQ>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -104,7 +104,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
@@ -728,23 +728,30 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
// Q tile in LDS
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem, true>());
|
||||
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
// auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
// static_cast<QDataType*>(smem_ptrk0),
|
||||
// Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<QDataType*>(smem_ptrk0),
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
// auto q_lds_store_window = make_tile_window(
|
||||
// q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<QDataType*>(smem_ptrk0),
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem, true>());
|
||||
|
||||
// auto q_lds_read_window =
|
||||
// make_tile_window(q_lds,
|
||||
// Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
// {0, 0},
|
||||
// Policy::template MakeQRegTileDistribution<Problem>());
|
||||
auto q_lds_store_window =
|
||||
make_tile_window(q_lds_write_view,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
// async_load_tile(q_lds_store_window, q_dram_window);
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_read_view,
|
||||
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
async_load_tile(q_lds_store_window, q_dram_window);
|
||||
block_sync_lds_direct_load<0>();
|
||||
auto q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
// K tile in LDS
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start;
|
||||
@@ -825,6 +832,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
block_sync_lds<0>();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
|
||||
|
||||
@@ -209,7 +209,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
return static_cast<index_t>(16 / sizeof(QDataType));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool Xor = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
@@ -217,11 +217,48 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
constexpr auto q_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
constexpr auto q_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_unmerged,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user