From 6cf17bc827fd52bbfbb9ddebda30a6cde68280a4 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 28 Jun 2026 13:24:05 +0000 Subject: [PATCH] Miscellaneous small updates and corrections --- .../example_hstu_attention_fwd.cpp | 8 ++++---- .../hstu_attention_fwd_kernel.hpp | 15 +++++++-------- .../hstu_attention_fwd_splitkv_kernel.hpp | 15 +++++++-------- .../18_hstu_attention/hstu_block_masking.hpp | 8 ++++---- ...hstu_generate_batched_random_number_uint16.cpp | 4 ++-- .../hstu_generate_batched_random_number_uint8.cpp | 4 ++-- .../hstu_generate_jagged_random_number_uint16.cpp | 4 ++-- .../hstu_generate_jagged_random_number_uint8.cpp | 4 ++-- .../hstu_rand_uniform_kernel.hpp | 8 ++------ 9 files changed, 32 insertions(+), 38 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp index b7e10e3618..e1f7c0edfc 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp @@ -463,7 +463,7 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag params.p_drop = p_drop; params.philox_seed = PHILOX_SEED; params.philox_offset = PHILOX_OFFSET; - }; + } bool has_dropout = (params.p_drop > 0.0f); ck_tile::HostTensor rand_vals_host( @@ -506,8 +506,8 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag rv_params.stride_batch = rand_vals_host.get_strides()[0]; rv_params.philox_seed = PHILOX_SEED; rv_params.philox_offset = PHILOX_OFFSET; - }; - }; + } + } hipStream_t stream; @@ -1023,7 +1023,7 @@ bool run_group_hstu_forward(const ck_tile::ArgParser& arg_parser, int num_group) rv_params.stride_nhead = rand_vals_host.get_strides()[2]; rv_params.philox_seed = PHILOX_SEED; rv_params.philox_offset = PHILOX_OFFSET; - }; + } hipStream_t stream; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 62d38e8e4f..50bde81a4f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -1057,10 +1057,10 @@ struct HstuAttentionFwdKernel }(); const auto [seqlen_k_start, seqlen_k_end] = - mask.template GetTileRangeAlongX( - i_m0, - number{}, - number{}); + mask.GetTileRangeAlongX(bool_constant{}, + i_m0, + number{}, + number{}); if constexpr(!kUseSoftmax) { @@ -1113,10 +1113,9 @@ struct HstuAttentionFwdKernel }(); const auto [seqlen_k_start, seqlen_k_end] = - mask.template GetTileRangeAlongX( - i_m0, - number{}, - number{}); + mask.GetTileRangeAlongX(i_m0, + number{}, + number{}); if constexpr(!kUseSoftmax) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index 7c3372f1e6..e31e9a0930 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -1060,10 +1060,10 @@ struct HstuAttentionFwdSplitKVKernel }(); const auto [global_seqlen_k_start, global_seqlen_k_end] = - mask.template GetTileRangeAlongX( - i_m0, - number{}, - number{}); + mask.GetTileRangeAlongX(bool_constant{}, + i_m0, + number{}, + number{}); const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); @@ -1119,10 +1119,9 @@ struct HstuAttentionFwdSplitKVKernel }(); const auto [global_seqlen_k_start, global_seqlen_k_end] = - mask.template GetTileRangeAlongX( - i_m0, - number{}, - number{}); + mask.GetTileRangeAlongX(i_m0, + number{}, + number{}); const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit( global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split); diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 2346f56c0d..0bd07324ba 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -79,7 +79,7 @@ struct HstuCrossAttentionBlockMaskWithLocal // i_y is the start offset of the current tile along the seqlen_q dimension template CK_TILE_DEVICE constexpr auto - GetTileRangeAlongX(index_t i_y, number, number) const + GetTileRangeAlongX(bool_constant, index_t i_y, number, number) const { // handle two special cases first if(!is_tile_in_first_split) @@ -327,7 +327,7 @@ struct HstuSelfAttentionBlockMaskWithLocal // i_y is the start offset of the current tile along the seqlen_q dimension template CK_TILE_DEVICE constexpr auto - GetTileRangeAlongX(index_t i_y, number, number) const + GetTileRangeAlongX(bool_constant, index_t i_y, number, number) const { // handle two special cases first if(!is_tile_in_first_split) @@ -563,7 +563,7 @@ struct HstuCrossAttentionBlockMaskNoLocal // to get the loop length along X axis, return index:[start, end), end-start=length // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) // i_y is the start offset of the current tile along the seqlen_q dimension - template + template CK_TILE_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { @@ -699,7 +699,7 @@ struct HstuSelfAttentionBlockMaskNoLocal // to get the loop length along X axis, return index:[start, end), end-start=length // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) // i_y is the start offset of the current tile along the seqlen_q dimension - template + template CK_TILE_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp index a535af0559..251056de2c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp @@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersPa param.stride_batch, {param.philox_seed, param.philox_offset}); - dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k); + dim3 kGridSize = + HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp index c9c92cd23a..460dae2b00 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp @@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersPar param.stride_batch, {param.philox_seed, param.philox_offset}); - dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k); + dim3 kGridSize = + HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp index 726f24dc0a..34e57ec8d3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp @@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersPar param.seq_k_offsets_ptr, {param.philox_seed, param.philox_offset}); - dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k); + dim3 kGridSize = + HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp index ddcf183a8b..8538cb2752 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp @@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersPara param.seq_k_offsets_ptr, {param.philox_seed, param.philox_offset}); - dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k); + dim3 kGridSize = + HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp index 1de7adb930..22a8132510 100644 --- a/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp @@ -156,13 +156,9 @@ struct HstuRandUniformKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t seqlen_k_) + __host__ static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) { - (void)seqlen_k_; // not used at present - // at present, seqlen_k is not splitted by thread-groups return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kMPerBlock), nhead_, batch_size_); }