mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Enable BATCH_AS_FIRST_GRID_DIM grid-scheduling and use ASSUME_LEAST_VARIED_SEQLEN for building control
This commit is contained in:
@@ -12,6 +12,10 @@ set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)
|
||||
|
||||
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
if (DEFINED ENV{ASSUME_LEAST_VARIED_SEQLEN})
|
||||
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=1)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -13,6 +13,10 @@
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
|
||||
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 0
|
||||
#endif
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
@@ -458,16 +462,29 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kSubQKHeaddim)
|
||||
{
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
return dim3(batch_size_,
|
||||
nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1));
|
||||
#else
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
return dim3(batch_size_,
|
||||
nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0));
|
||||
#else
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -478,9 +495,15 @@ struct HstuAttentionFwdKernel
|
||||
const index_t num_tile_n1 =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
const index_t i_batch = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
@@ -488,17 +511,33 @@ struct HstuAttentionFwdKernel
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
i_tile_m = gridDim.z / num_tile_n1 - 1 - i_tile_m;
|
||||
#else
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
#endif
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
const index_t i_batch = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
const index_t i_tile_m = i_block;
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
index_t i_tile_m = i_block;
|
||||
i_tile_m = gridDim.z - 1 - i_tile_m;
|
||||
#else
|
||||
const index_t i_tile_m = i_block;
|
||||
#endif
|
||||
const index_t i_tile_n = 0;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
|
||||
Reference in New Issue
Block a user