mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Use better naming for tile indices
This commit is contained in:
@@ -299,10 +299,10 @@ struct FmhaFwdAppendKVKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_sk, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
|
||||
|
||||
const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk);
|
||||
// const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeS);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeSk);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
@@ -514,9 +514,8 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
/// TODO: use tile idx for q
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
|
||||
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -540,9 +539,8 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
/// TODO: use tile idx for q
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
|
||||
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -570,7 +568,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
|
||||
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -595,7 +593,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
|
||||
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -603,31 +601,30 @@ struct FmhaFwdAppendKVKernel
|
||||
}
|
||||
}();
|
||||
|
||||
/// TODO: use tile idx for q
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{0, 0});
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{kargs.seqlen_k + i_sk, 0});
|
||||
{kargs.seqlen_k + i_n0, 0});
|
||||
|
||||
auto knew_dram_window = make_tile_window(
|
||||
knew_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{i_sk, 0});
|
||||
{i_n0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
|
||||
{0, kargs.seqlen_k + i_sk});
|
||||
{0, kargs.seqlen_k + i_n0});
|
||||
|
||||
auto vnew_dram_window = make_tile_window(
|
||||
vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
|
||||
{0, i_sk});
|
||||
{0, i_n0});
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
@@ -640,6 +637,8 @@ struct FmhaFwdAppendKVKernel
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0,
|
||||
smem_ptr,
|
||||
kargs.rotary_dim);
|
||||
}
|
||||
@@ -654,6 +653,8 @@ struct FmhaFwdAppendKVKernel
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0,
|
||||
smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,35 +19,23 @@ struct FmhaFwdAppendKVTilePartitioner
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_v)
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_knew)
|
||||
{
|
||||
assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1);
|
||||
#ifdef NDEBUG
|
||||
ignore = hdim_v;
|
||||
#endif
|
||||
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size);
|
||||
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kTileSizeS),
|
||||
ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk)),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
CK_TILE_DEVICE auto operator()()
|
||||
{
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_tile = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
(void)f;
|
||||
// const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user