diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp index 0485a6b3c1..519d4691a7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_api.hpp @@ -13,3 +13,11 @@ extern void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param hipStream_t stream); extern void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param, hipStream_t stream); +extern void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param, + hipStream_t stream); +extern void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param, + hipStream_t stream); +extern void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param, + hipStream_t stream); +extern void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param, + hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index a89f33c8da..d1baa8cd74 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -138,18 +138,17 @@ struct HstuGenerateRandUniformNumbersParams void* rand_val_ptr; ck_tile::index_t num_batches; - ck_tile::index_t seqlen_q; // batched mode only - ck_tile::index_t seqlen_k; // batched mode only - const void* seq_q_offsets_ptr; // jagged mode only - const void* seq_k_offsets_ptr; // jagged mode only - ck_tile::index_t max_seqlen_q; // jagged mode only + ck_tile::index_t seqlen_q; // batched mode only + ck_tile::index_t seqlen_k; // batched mode only + const void* seq_q_offsets_ptr; // jagged mode only + const void* seq_k_offsets_ptr; // jagged mode only + ck_tile::index_t max_seqlen_q; // jagged mode only ck_tile::index_t num_heads; ck_tile::index_t 
stride_seqlen_q; - ck_tile::index_t stride_seqlen_k; ck_tile::index_t stride_nhead; - ck_tile::index_t stride_batch; // batched mode only + ck_tile::index_t stride_batch; // batched mode only uint64_t philox_seed; uint64_t philox_offset; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp index 263d0e4382..479f5576fa 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint16.cpp @@ -1,6 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "hstu_attention_params.hpp" +#include "hstu_rand_uniform_kernel.hpp" + void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param, hipStream_t stream) { diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp index ac62d8f241..840637101d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_batched_random_number_uint8.cpp @@ -1,8 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "hstu_attention_params.hpp" +#include "hstu_rand_uniform_kernel.hpp" + void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param, - hipStream_t stream) + hipStream_t stream) { // only work for batched mode using HstuRandUniformKernel_ = HstuRandUniformKernel; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp index 059f364229..23a29279c3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint16.cpp @@ -1,26 +1,26 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "hstu_attention_params.hpp" +#include "hstu_rand_uniform_kernel.hpp" + void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param, - hipStream_t stream) + hipStream_t stream) { // only work for jagged mode using HstuRandUniformKernel_ = HstuRandUniformKernel; const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr, - param.seqlen_q, - param.seqlen_k, param.num_heads, param.num_batches, param.stride_seqlen_q, - param.stride_seqlen_k, param.stride_nhead, - param.seqstart_q_ptr, - param.seqstart_k_ptr, + param.seq_q_offsets_ptr, + param.seq_k_offsets_ptr, {param.philox_seed, param.philox_offset}); dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k); + param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp index 
3748c9bda4..7ceae795cf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp +++ b/example/ck_tile/18_hstu_attention/hstu_generate_jagged_random_number_uint8.cpp @@ -1,26 +1,26 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "hstu_attention_params.hpp" +#include "hstu_rand_uniform_kernel.hpp" + void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param, - hipStream_t stream) + hipStream_t stream) { // only work for jagged mode using HstuRandUniformKernel_ = HstuRandUniformKernel; const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr, - param.seqlen_q, - param.seqlen_k, param.num_heads, param.num_batches, param.stride_seqlen_q, - param.stride_seqlen_k, param.stride_nhead, - param.seqstart_q_ptr, - param.seqstart_k_ptr, + param.seq_q_offsets_ptr, + param.seq_k_offsets_ptr, {param.philox_seed, param.philox_offset}); dim3 kGridSize = HstuRandUniformKernel_::GridSize( - param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k); + param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k); dim3 kBlockSize = HstuRandUniformKernel_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu; diff --git a/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp index d9fab04328..6cf31ba8d7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_rand_uniform_kernel.hpp @@ -94,7 +94,9 @@ struct HstuRandUniformKernel struct HstuRandUniformJaggedKargs : HstuRandUniformCommonKargs { - const int32_t* seqstart_q_ptr; + const int32_t* seq_q_offsets_ptr; + // provide seqlen_k for each batch, aligned with jagged k/v tensor lay-out + const int32_t* seq_k_offsets_ptr; }; using Kargs = @@ -133,12 +135,12 @@ struct HstuRandUniformKernel 
ck_tile::index_t num_batches, ck_tile::index_t stride_seqlen_q, ck_tile::index_t stride_nhead, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, + const void* seq_q_offsets_ptr, + const void* seq_k_offsets_ptr, std::tuple drop_seed_offset) { Kargs kargs{{rand_val_ptr, - -1, // seqlen_q will be update in the kernel + -1, // seqlen_q will be update in the kernel -1, // seqlen_k will be update in the kernel num_heads, num_batches, @@ -146,8 +148,8 @@ struct HstuRandUniformKernel stride_nhead, std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)}, - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_q_ptr)}; + reinterpret_cast(seq_q_offsets_ptr), + reinterpret_cast(seq_k_offsets_ptr)}; return kargs; } @@ -233,21 +235,21 @@ struct HstuRandUniformKernel if constexpr(kIsJagged) { // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch]; batch_offset_randval = query_start * kargs.stride_seqlen_q; // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const auto adjusted_seq_q_offsets_ptr = kargs.seq_q_offsets_ptr + i_batch; + kargs.seqlen_q = adjusted_seq_q_offsets_ptr[1] - adjusted_seq_q_offsets_ptr[0]; if(kargs.seqlen_q <= i_m0) { return; } - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + const auto adjusted_seq_k_offsets_ptr = kargs.seq_k_offsets_ptr + i_batch; + kargs.seqlen_k = adjusted_seq_k_offsets_ptr[1] - adjusted_seq_k_offsets_ptr[0]; } else {