From 0306a1e6ab74117a4661f8857ff7010c10b9d82b Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 17 Jul 2025 04:48:55 +0000 Subject: [PATCH] Re-org the kernel parameters in HstuAttentionFwdBatchModeBaseKargs and HstuAttentionFwdJaggModeBaseKargs --- .../hstu_attention_fwd_kernel.hpp | 230 ++++++++++-------- 1 file changed, 133 insertions(+), 97 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 41ed0ebde2..a3dae30ea4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -62,34 +62,80 @@ struct HstuAttentionFwdKernel // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. - struct HstuAttentionFwdCommonKargs + struct HstuAttentionFwdBatchModeBaseKargs { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + + const int32_t* num_targets_ptr; + const void* q_ptr; const void* k_ptr; const void* v_ptr; void* o_ptr; - ck_tile::index_t seqlen; - ck_tile::index_t hdim_qk; - ck_tile::index_t hdim_v; - - ck_tile::index_t num_head; - float scale_s; - - ck_tile::index_t seq_stride_q; - ck_tile::index_t seq_stride_k; - ck_tile::index_t seq_stride_v; - ck_tile::index_t seq_stride_o; - ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - const int32_t* num_targets_ptr; + ck_tile::index_t seqlen; + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + ck_tile::index_t seq_stride_o; + + ck_tile::index_t num_head; + float scale_s; + ck_tile::index_t max_seqlen; + ck_tile::index_t contextual_seqlen; }; + struct HstuAttentionFwdJaggModeBaseKargs + { + const int32_t* seq_offsets_ptr; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + ck_tile::index_t seq_stride_o; + + const int32_t* num_targets_ptr; + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t seqlen; + + ck_tile::index_t num_head; + float scale_s; + ck_tile::index_t max_seqlen; + + ck_tile::index_t contextual_seqlen; + }; + + struct HstuAttentionFwdMaskKargs + { + ck_tile::index_t window_size; + ck_tile::index_t min_full_attn_seqlen; + }; + struct HstuAttentionFwdCommonBiasKargs { const void* bias_ptr = nullptr; @@ -102,12 +148,6 @@ struct HstuAttentionFwdKernel ck_tile::index_t batch_stride_bias = 0; }; - struct HstuAttentionFwdMaskKargs - { - ck_tile::index_t window_size; - ck_tile::index_t min_full_attn_seqlen; - }; - struct HstuAttentionFwdDropoutSeedOffset { uint64_t drop_seed; @@ -131,38 +171,30 @@ struct HstuAttentionFwdKernel uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); }; - struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs, - std::conditional_t>, + struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs, std::conditional_t>, + std::conditional_t>, std::conditional_t> { - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_o; - - ck_tile::index_t max_seqlen; }; - struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs, - std::conditional_t>, + struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs, std::conditional_t>, + std::conditional_t>, std::conditional_t> { - const int32_t* seq_offsets_ptr; - ck_tile::index_t max_seqlen; }; using Kargs = std:: @@ -202,34 +234,41 @@ struct HstuAttentionFwdKernel float p_drop, const std::pair& drop_seed_offset) { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - o_ptr, - seqlen, - hdim_qk, - hdim_v, - num_head, - -scale_s, - seq_stride_q, - seq_stride_k, - seq_stride_v, - seq_stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - reinterpret_cast(num_targets_ptr), - contextual_seqlen}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for dropout - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_o, - seqlen}; // max_seqlen + Kargs kargs{ + {batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o, + reinterpret_cast(num_targets_ptr), + q_ptr, + k_ptr, + v_ptr, + o_ptr, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + seqlen, + hdim_qk, + hdim_v, + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_o, + num_head, + -scale_s, + seqlen, // max_seqlen + contextual_seqlen}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for bias + {}, // placeholder for dropout + }; + if constexpr(kHasLocalMask) + { + kargs.window_size = window_size; + kargs.min_full_attn_seqlen = min_full_attn_seqlen; + } if constexpr(kHasBias) { kargs.bias_ptr = bias_ptr; @@ -237,11 +276,6 @@ struct HstuAttentionFwdKernel kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } - if constexpr(kHasLocalMask) - { - kargs.window_size = window_size; - kargs.min_full_attn_seqlen = min_full_attn_seqlen; - } if constexpr(kHasDropout) { auto seed = std::get<0>(drop_seed_offset); @@ -350,42 +384,44 @@ struct HstuAttentionFwdKernel float p_drop, const std::pair& drop_seed_offset) { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - o_ptr, - -1, // seqlen will be updated by another pointer - hdim_qk, - hdim_v, - num_head, - -scale_s, - seq_stride_q, - seq_stride_k, - seq_stride_v, - seq_stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - reinterpret_cast(num_targets_ptr), - contextual_seqlen}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for dropout - reinterpret_cast(seq_offsets_ptr), - max_seqlen}; + Kargs kargs{ + {reinterpret_cast(seq_offsets_ptr), + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_o, + reinterpret_cast(num_targets_ptr), + q_ptr, + k_ptr, + v_ptr, + o_ptr, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + hdim_qk, + hdim_v, + -1, // seqlen will be updated by another pointer + num_head, + -scale_s, + max_seqlen, + contextual_seqlen}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for bias + {}, // placeholder for dropout + }; + if constexpr(kHasLocalMask) + { + kargs.window_size = window_size; + kargs.min_full_attn_seqlen = min_full_attn_seqlen; + } if constexpr(kHasBias) { kargs.bias_ptr = bias_ptr; kargs.seq_stride_bias = seq_stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } - if constexpr(kHasLocalMask) - { - kargs.window_size = window_size; - kargs.min_full_attn_seqlen = min_full_attn_seqlen; - } if constexpr(kHasDropout) { auto seed = std::get<0>(drop_seed_offset);