mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add bf16 support for vsa_sparse_attention
This commit is contained in:
@@ -33,17 +33,18 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k);
|
||||
|
||||
ck_tile::HostTensor<DataType>
|
||||
vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
ck_tile::HostTensor<DataType>& TK,
|
||||
ck_tile::HostTensor<DataType>& TV,
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<DataType_>& TV,
|
||||
ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
|
||||
ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
|
||||
ck_tile::HostTensor<DataType>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
int batch,
|
||||
int nhead,
|
||||
|
||||
@@ -15,9 +15,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// Define DataType before including the header
|
||||
using DataType = ck_tile::half_t;
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
|
||||
@@ -363,29 +360,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparse_attention(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
@@ -394,29 +391,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparse_attention(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
lut_host,
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
@@ -526,10 +523,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
std::cout << "Note: Using bf16 precision" << std::endl;
|
||||
// For bf16, we would need to compile with DataType = ck_tile::bf16_t
|
||||
// For now, run with the compiled DataType
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -6,18 +6,20 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
ck_tile::HostTensor<DataType>
|
||||
vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
ck_tile::HostTensor<DataType>& TK,
|
||||
ck_tile::HostTensor<DataType>& TV,
|
||||
ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
|
||||
ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
|
||||
ck_tile::HostTensor<DataType>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<DataType_>& TV,
|
||||
ck_tile::HostTensor<int32_t>& TKV_block_idx,
|
||||
ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
int batch,
|
||||
int nhead,
|
||||
@@ -32,8 +34,12 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k)
|
||||
{
|
||||
// Determine data type string based on template parameter
|
||||
std::string data_type = "fp16";
|
||||
// DataType is determined at compile time via template
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
@@ -218,3 +224,26 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
// Explicit template instantiations
// Instantiates vsa_sparse_attention for ck_tile::half_t (the fp16 path).
// The argument list below must match the template's parameter list exactly:
// three data tensors, two int32_t tensors, the output tensor, four optional
// tensors, then nine ints, two bools and two ints.
// NOTE(review): presumably required because the template is only declared in
// a header and defined in this translation unit — confirm against the header.
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
vsa_sparse_attention<ck_tile::half_t>(
|
||||
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<int32_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
// Explicit instantiation of vsa_sparse_attention for ck_tile::bf16_t.
// The argument list must match the template's parameter list exactly:
// three data tensors, two int32_t tensors, the output tensor, four optional
// tensors, then nine ints, two bools and two ints.
// NOTE(review): presumably this is what lets run_test<ck_tile::bf16_t> link
// against the bf16 version — confirm against the build.
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
vsa_sparse_attention<ck_tile::bf16_t>(
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<int32_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
Reference in New Issue
Block a user