Use better naming for tile indices

This commit is contained in:
PoYen, Chen
2024-07-23 06:40:53 +00:00
parent bc7c7ee0c5
commit 0925c0e941
2 changed files with 25 additions and 36 deletions

View File

@@ -299,10 +299,10 @@ struct FmhaFwdAppendKVKernel
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_sk, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk);
// const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeS);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kTileSizeSk);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
@@ -514,9 +514,8 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
/// TODO: use tile idx for q
return make_tile_window(
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
@@ -540,9 +539,8 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
/// TODO: use tile idx for q
return make_tile_window(
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
@@ -570,7 +568,7 @@ struct FmhaFwdAppendKVKernel
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
@@ -595,7 +593,7 @@ struct FmhaFwdAppendKVKernel
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0});
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
@@ -603,31 +601,30 @@ struct FmhaFwdAppendKVKernel
}
}();
/// TODO: use tile idx for q
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD>{}),
{0, 0});
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
{kargs.seqlen_k + i_sk, 0});
{kargs.seqlen_k + i_n0, 0});
auto knew_dram_window = make_tile_window(
knew_dram,
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
{i_sk, 0});
{i_n0, 0});
auto v_dram_window = make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
{0, kargs.seqlen_k + i_sk});
{0, kargs.seqlen_k + i_n0});
auto vnew_dram_window = make_tile_window(
vnew_dram,
make_tuple(number<FmhaPipeline::kTileSizeDv>{}, number<FmhaPipeline::kTileSizeSk>{}),
{0, i_sk});
{0, i_n0});
if constexpr(kApplyRoPE)
{
@@ -640,6 +637,8 @@ struct FmhaFwdAppendKVKernel
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0,
smem_ptr,
kargs.rotary_dim);
}
@@ -654,6 +653,8 @@ struct FmhaFwdAppendKVKernel
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0,
smem_ptr);
}
}

View File

@@ -19,35 +19,23 @@ struct FmhaFwdAppendKVTilePartitioner
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_v)
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1);
#ifdef NDEBUG
ignore = hdim_v;
#endif
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size);
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kTileSizeS),
ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk)),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
CK_TILE_DEVICE auto operator()()
{
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
const index_t i_tile = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
(void)f;
// const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
}
};