diff --git a/example/ck_tile/18_hstu_attention/CMakeLists.txt b/example/ck_tile/18_hstu_attention/CMakeLists.txt index e8e39afa65..c62df256f3 100644 --- a/example/ck_tile/18_hstu_attention/CMakeLists.txt +++ b/example/ck_tile/18_hstu_attention/CMakeLists.txt @@ -12,6 +12,10 @@ set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS) list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +if (DEFINED ENV{ASSUME_LEAST_VARIED_SEQLEN}) + list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=1) +endif() + target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 47c404742e..1f1f032981 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -13,6 +13,10 @@ #include "hstu_block_masking.hpp" +#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM +#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 0 +#endif + // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -458,16 +462,29 @@ struct HstuAttentionFwdKernel { if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim) { +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + return dim3(batch_size_, + nhead_, + ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1)); +#else return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), nhead_, batch_size_); +#endif } else { +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + return dim3(batch_size_, + nhead_, + ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0)); +#else return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0), nhead_, batch_size_); +#endif } } @@ -478,9 +495,15 @@ struct HstuAttentionFwdKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1); - const index_t i_block = blockIdx.x; +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + const index_t i_batch = blockIdx.x; const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_block = blockIdx.z; +#else + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; +#endif const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; @@ -488,17 +511,33 @@ struct HstuAttentionFwdKernel return ck_tile::make_tuple(quotient, modulus); }; +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + i_tile_m = gridDim.z / num_tile_n1 - 1 - i_tile_m; +#else const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); +#endif return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } else { - const index_t i_block = blockIdx.x; +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + const index_t i_batch = blockIdx.x; const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_block = blockIdx.z; +#else + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; +#endif - const index_t i_tile_m = i_block; +#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM + index_t i_tile_m = i_block; + i_tile_m = gridDim.z - 1 - i_tile_m; +#else + const index_t i_tile_m = i_block; +#endif const index_t i_tile_n = 0; return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);