Update to kernel and host interface for generating random numbers

This commit is contained in:
Qianfeng Zhang
2026-06-24 08:15:23 +00:00
parent 8a4ec6382d
commit 58fbfe766e
7 changed files with 48 additions and 33 deletions

View File

@@ -13,3 +13,11 @@ extern void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param
hipStream_t stream);
extern void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param,
hipStream_t stream);
// Host entry points for random-number generation: one per layout mode
// (batched / jagged) and per output element width (uint8 / uint16).
// Each takes the shared HstuGenerateRandUniformNumbersParams and a HIP stream.
extern void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream);
extern void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream);
extern void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream);
extern void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream);

View File

@@ -138,18 +138,17 @@ struct HstuGenerateRandUniformNumbersParams
void* rand_val_ptr;
ck_tile::index_t num_batches;
ck_tile::index_t seqlen_q; // batched mode only
ck_tile::index_t seqlen_k; // batched mode only
const void* seq_q_offsets_ptr; // jagged mode only
const void* seq_k_offsets_ptr; // jagged mode only
ck_tile::index_t max_seqlen_q; // jagged mode only
ck_tile::index_t seqlen_q; // batched mode only
ck_tile::index_t seqlen_k; // batched mode only
const void* seq_q_offsets_ptr; // jagged mode only
const void* seq_k_offsets_ptr; // jagged mode only
ck_tile::index_t max_seqlen_q; // jagged mode only
ck_tile::index_t num_heads;
ck_tile::index_t stride_seqlen_q;
ck_tile::index_t stride_seqlen_k;
ck_tile::index_t stride_nhead;
ck_tile::index_t stride_batch; // batched mode only
ck_tile::index_t stride_batch; // batched mode only
uint64_t philox_seed;
uint64_t philox_offset;

View File

@@ -1,6 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "hstu_attention_params.hpp"
#include "hstu_rand_uniform_kernel.hpp"
void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream)
{

View File

@@ -1,8 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "hstu_attention_params.hpp"
#include "hstu_rand_uniform_kernel.hpp"
void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream)
hipStream_t stream)
{
// only work for batched mode
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint8_t, false>;

View File

@@ -1,26 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "hstu_attention_params.hpp"
#include "hstu_rand_uniform_kernel.hpp"
void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream)
hipStream_t stream)
{
// only work for jagged mode
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint16_t, true>;
const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr,
param.seqlen_q,
param.seqlen_k,
param.num_heads,
param.num_batches,
param.stride_seqlen_q,
param.stride_seqlen_k,
param.stride_nhead,
param.seqstart_q_ptr,
param.seqstart_k_ptr,
param.seq_q_offsets_ptr,
param.seq_k_offsets_ptr,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k);
param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -1,26 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "hstu_attention_params.hpp"
#include "hstu_rand_uniform_kernel.hpp"
void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
hipStream_t stream)
hipStream_t stream)
{
// only work for jagged mode
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint8_t, true>;
const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr,
param.seqlen_q,
param.seqlen_k,
param.num_heads,
param.num_batches,
param.stride_seqlen_q,
param.stride_seqlen_k,
param.stride_nhead,
param.seqstart_q_ptr,
param.seqstart_k_ptr,
param.seq_q_offsets_ptr,
param.seq_k_offsets_ptr,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k);
param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -94,7 +94,9 @@ struct HstuRandUniformKernel
// Kernel arguments for the jagged layout: the common arguments plus the
// per-batch offset tables that the kernel's kIsJagged branch reads.
struct HstuRandUniformJaggedKargs : HstuRandUniformCommonKargs
{
// NOTE(review): seq_q_offsets_ptr below looks like the renamed replacement for
// this member — confirm whether it is still needed.
const int32_t* seqstart_q_ptr;
// query offsets: entry [i] is the first query row of batch i, and
// entry [i + 1] - entry [i] is that batch's seqlen_q
const int32_t* seq_q_offsets_ptr;
// provide seqlen_k for each batch, aligned with jagged k/v tensor layout
const int32_t* seq_k_offsets_ptr;
};
using Kargs =
@@ -133,12 +135,12 @@ struct HstuRandUniformKernel
ck_tile::index_t num_batches,
ck_tile::index_t stride_seqlen_q,
ck_tile::index_t stride_nhead,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seq_q_offsets_ptr,
const void* seq_k_offsets_ptr,
std::tuple<uint64_t, uint64_t> drop_seed_offset)
{
Kargs kargs{{rand_val_ptr,
-1, // seqlen_q will be update in the kernel
-1, // seqlen_q will be update in the kernel
-1, // seqlen_k will be update in the kernel
num_heads,
num_batches,
@@ -146,8 +148,8 @@ struct HstuRandUniformKernel
stride_nhead,
std::get<0>(drop_seed_offset),
std::get<1>(drop_seed_offset)},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
// The last two initializers fill seq_q_offsets_ptr and seq_k_offsets_ptr in
// that order. The key offsets must come from seq_k_offsets_ptr: the kernel
// derives seqlen_k from them, and passing the query offsets twice would make
// seqlen_k silently equal seqlen_q.
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
reinterpret_cast<const int32_t*>(seq_k_offsets_ptr)};
return kargs;
}
@@ -233,21 +235,21 @@ struct HstuRandUniformKernel
if constexpr(kIsJagged)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];
batch_offset_randval = query_start * kargs.stride_seqlen_q;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
const auto adjusted_seq_q_offsets_ptr = kargs.seq_q_offsets_ptr + i_batch;
kargs.seqlen_q = adjusted_seq_q_offsets_ptr[1] - adjusted_seq_q_offsets_ptr[0];
if(kargs.seqlen_q <= i_m0)
{
return;
}
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
const auto adjusted_seq_k_offsets_ptr = kargs.seq_k_offsets_ptr + i_batch;
kargs.seqlen_k = adjusted_seq_k_offsets_ptr[1] - adjusted_seq_k_offsets_ptr[0];
}
else
{