Update the GridSize() and GetTileIndex() in hstu kernel

This commit is contained in:
Qianfeng Zhang
2025-04-26 10:01:23 +00:00
parent 1b463e915d
commit 95c93ba92e

View File

@@ -448,32 +448,53 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seqlen_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
nhead_,
batch_size_);
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
{
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
nhead_,
batch_size_);
}
else
{
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0),
nhead_,
batch_size_);
}
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 =
ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1);
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
{
const index_t num_tile_n1 =
ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_tile_m = i_block;
const index_t i_tile_n = 0;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }