mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Update to kernel and host interface for generating random numbers
This commit is contained in:
@@ -13,3 +13,11 @@ extern void hstu_attention_group_forward_fp16(HstuAttentionGroupFwdParams& param
|
||||
hipStream_t stream);
|
||||
extern void hstu_attention_group_forward_bf16(HstuAttentionGroupFwdParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream);
|
||||
extern void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream);
|
||||
|
||||
@@ -138,18 +138,17 @@ struct HstuGenerateRandUniformNumbersParams
|
||||
void* rand_val_ptr;
|
||||
|
||||
ck_tile::index_t num_batches;
|
||||
ck_tile::index_t seqlen_q; // batched mode only
|
||||
ck_tile::index_t seqlen_k; // batched mode only
|
||||
const void* seq_q_offsets_ptr; // jagged mode only
|
||||
const void* seq_k_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen_q; // jagged mode only
|
||||
ck_tile::index_t seqlen_q; // batched mode only
|
||||
ck_tile::index_t seqlen_k; // batched mode only
|
||||
const void* seq_q_offsets_ptr; // jagged mode only
|
||||
const void* seq_k_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen_q; // jagged mode only
|
||||
|
||||
ck_tile::index_t num_heads;
|
||||
|
||||
ck_tile::index_t stride_seqlen_q;
|
||||
ck_tile::index_t stride_seqlen_k;
|
||||
ck_tile::index_t stride_nhead;
|
||||
ck_tile::index_t stride_batch; // batched mode only
|
||||
ck_tile::index_t stride_batch; // batched mode only
|
||||
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_rand_uniform_kernel.hpp"
|
||||
|
||||
void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_rand_uniform_kernel.hpp"
|
||||
|
||||
void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream)
|
||||
hipStream_t stream)
|
||||
{
|
||||
// only work for batched mode
|
||||
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint8_t, false>;
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_rand_uniform_kernel.hpp"
|
||||
|
||||
void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream)
|
||||
hipStream_t stream)
|
||||
{
|
||||
// only work for jagged mode
|
||||
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint16_t, true>;
|
||||
|
||||
const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr,
|
||||
param.seqlen_q,
|
||||
param.seqlen_k,
|
||||
param.num_heads,
|
||||
param.num_batches,
|
||||
param.stride_seqlen_q,
|
||||
param.stride_seqlen_k,
|
||||
param.stride_nhead,
|
||||
param.seqstart_q_ptr,
|
||||
param.seqstart_k_ptr,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.seq_k_offsets_ptr,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k);
|
||||
param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_rand_uniform_kernel.hpp"
|
||||
|
||||
void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersParams& param,
|
||||
hipStream_t stream)
|
||||
hipStream_t stream)
|
||||
{
|
||||
// only work for jagged mode
|
||||
using HstuRandUniformKernel_ = HstuRandUniformKernel<uint8_t, true>;
|
||||
|
||||
const auto kargs = HstuRandUniformKernel_::MakeKargs(param.rand_val_ptr,
|
||||
param.seqlen_q,
|
||||
param.seqlen_k,
|
||||
param.num_heads,
|
||||
param.num_batches,
|
||||
param.stride_seqlen_q,
|
||||
param.stride_seqlen_k,
|
||||
param.stride_nhead,
|
||||
param.seqstart_q_ptr,
|
||||
param.seqstart_k_ptr,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.seq_k_offsets_ptr,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batches, param.num_heads, param.seqlen_q, param.seqlen_k);
|
||||
param.num_batches, param.num_heads, param.max_seqlen_q, param.seqlen_k);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -94,7 +94,9 @@ struct HstuRandUniformKernel
|
||||
|
||||
struct HstuRandUniformJaggedKargs : HstuRandUniformCommonKargs
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seq_q_offsets_ptr;
|
||||
// provide seqlen_k for each batch, aligned with jagged k/v tensor lay-out
|
||||
const int32_t* seq_k_offsets_ptr;
|
||||
};
|
||||
|
||||
using Kargs =
|
||||
@@ -133,12 +135,12 @@ struct HstuRandUniformKernel
|
||||
ck_tile::index_t num_batches,
|
||||
ck_tile::index_t stride_seqlen_q,
|
||||
ck_tile::index_t stride_nhead,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_k_offsets_ptr,
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{rand_val_ptr,
|
||||
-1, // seqlen_q will be update in the kernel
|
||||
-1, // seqlen_q will be update in the kernel
|
||||
-1, // seqlen_k will be update in the kernel
|
||||
num_heads,
|
||||
num_batches,
|
||||
@@ -146,8 +148,8 @@ struct HstuRandUniformKernel
|
||||
stride_nhead,
|
||||
std::get<0>(drop_seed_offset),
|
||||
std::get<1>(drop_seed_offset)},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
|
||||
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr)};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -233,21 +235,21 @@ struct HstuRandUniformKernel
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t query_start = kargs.seq_q_offsets_ptr[i_batch];
|
||||
|
||||
batch_offset_randval = query_start * kargs.stride_seqlen_q;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
const auto adjusted_seq_q_offsets_ptr = kargs.seq_q_offsets_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seq_q_offsets_ptr[1] - adjusted_seq_q_offsets_ptr[0];
|
||||
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
const auto adjusted_seq_k_offsets_ptr = kargs.seq_k_offsets_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seq_k_offsets_ptr[1] - adjusted_seq_k_offsets_ptr[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user