Enable BATCH_AS_FIRST_GRID_DIM grid-scheduling and use ASSUME_LEAST_VARIED_SEQLEN for building control

This commit is contained in:
Qianfeng Zhang
2025-06-10 15:43:19 +00:00
parent 4632d30cc0
commit 08886e99d5
2 changed files with 48 additions and 5 deletions

View File

@@ -12,6 +12,10 @@ set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
if (DEFINED ENV{ASSUME_LEAST_VARIED_SEQLEN})
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=1)
endif()
target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated

View File

@@ -13,6 +13,10 @@
#include "hstu_block_masking.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 0
#endif
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
@@ -458,16 +462,29 @@ struct HstuAttentionFwdKernel
{
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
{
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
return dim3(batch_size_,
nhead_,
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1));
#else
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
nhead_,
batch_size_);
#endif
}
else
{
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
return dim3(batch_size_,
nhead_,
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0));
#else
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0),
nhead_,
batch_size_);
#endif
}
}
@@ -478,9 +495,15 @@ struct HstuAttentionFwdKernel
const index_t num_tile_n1 =
ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1);
const index_t i_block = blockIdx.x;
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
const index_t i_batch = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_block = blockIdx.z;
#else
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
#endif
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
@@ -488,17 +511,33 @@ struct HstuAttentionFwdKernel
return ck_tile::make_tuple(quotient, modulus);
};
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
i_tile_m = gridDim.z / num_tile_n1 - 1 - i_tile_m;
#else
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
#endif
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
const index_t i_block = blockIdx.x;
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
const index_t i_batch = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const index_t i_block = blockIdx.z;
#else
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
#endif
const index_t i_tile_m = i_block;
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
index_t i_tile_m = i_block;
i_tile_m = gridDim.z - 1 - i_tile_m;
#else
const index_t i_tile_m = i_block;
#endif
const index_t i_tile_n = 0;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);