Add bf16 support for vsa_sparse_attention

This commit is contained in:
Jiangyon
2025-12-16 06:18:13 +00:00
parent 3b00e4022d
commit 997ec8f89c
3 changed files with 98 additions and 74 deletions

View File

@@ -33,17 +33,18 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
int max_seqlen_q,
int max_seqlen_k);
ck_tile::HostTensor<DataType>
vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
ck_tile::HostTensor<DataType>& TK,
ck_tile::HostTensor<DataType>& TV,
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<DataType_>& TV,
ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
ck_tile::HostTensor<DataType>& Y,
std::optional<ck_tile::HostTensor<DataType>> bias,
std::optional<ck_tile::HostTensor<DataType>> lse,
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
ck_tile::HostTensor<DataType_>& Y,
std::optional<ck_tile::HostTensor<DataType_>> bias,
std::optional<ck_tile::HostTensor<DataType_>> lse,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
int bias_type,
int batch,
int nhead,

View File

@@ -15,9 +15,6 @@
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
// Define DataType before including the header
using DataType = ck_tile::half_t;
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
@@ -363,29 +360,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
// Warmup
for(int i = 0; i < warmup; ++i)
{
vsa_sparse_attention(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
}
// Benchmark
@@ -394,29 +391,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
for(int i = 0; i < repeat; ++i)
{
vsa_sparse_attention(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
@@ -526,10 +523,7 @@ int main(int argc, char* argv[])
}
else if(prec == "bf16")
{
std::cout << "Note: Using bf16 precision" << std::endl;
// For bf16, we would need to compile with DataType = ck_tile::bf16_t
// For now, run with the compiled DataType
test_result = run_test<ck_tile::half_t>(arg_parser);
test_result = run_test<ck_tile::bf16_t>(arg_parser);
}
else
{

View File

@@ -6,18 +6,20 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
ck_tile::HostTensor<DataType>
vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
ck_tile::HostTensor<DataType>& TK,
ck_tile::HostTensor<DataType>& TV,
ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
ck_tile::HostTensor<DataType>& Y,
std::optional<ck_tile::HostTensor<DataType>> bias,
std::optional<ck_tile::HostTensor<DataType>> lse,
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<DataType_>& TV,
ck_tile::HostTensor<int32_t>& TKV_block_idx,
ck_tile::HostTensor<int32_t>& TKV_blocks,
ck_tile::HostTensor<DataType_>& Y,
std::optional<ck_tile::HostTensor<DataType_>> bias,
std::optional<ck_tile::HostTensor<DataType_>> lse,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
int bias_type,
int batch,
int nhead,
@@ -32,8 +34,12 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
int max_seqlen_q,
int max_seqlen_k)
{
// Determine data type string based on template parameter
std::string data_type = "fp16";
// DataType is determined at compile time via template
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
if(max_seqlen_q == 0)
max_seqlen_q = seqlen_q;
@@ -218,3 +224,26 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
return Y;
}
// Explicit template instantiations
//
// fp16 variant: instantiates vsa_sparse_attention with DataType_ = ck_tile::half_t.
// For this type the function keeps its default data_type string "fp16".
//
// The argument types listed here must match the template's parameter list
// exactly and in order:
//   - three DataType_ tensors by reference (TQ, TK, TV)
//   - two int32_t tensors by reference (TKV_block_idx, TKV_blocks)
//   - the DataType_ tensor Y by reference (also the function's return value)
//   - four optional DataType_ tensors by value (bias, lse, seqstart_q, seqstart_k)
//   - the trailing scalar int/bool arguments
//
// NOTE(review): presumably required because callers only see a declaration of
// the template and rely on this translation unit for the definition - confirm
// against the header.
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparse_attention<ck_tile::half_t>(
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<int32_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
// bf16 variant: explicit instantiation of vsa_sparse_attention with
// DataType_ = ck_tile::bf16_t. For this type the function's `if constexpr`
// branch switches its data_type string from the default "fp16" to "bf16".
//
// The argument types must match the template's parameter list exactly and in
// order: three DataType_ tensor references (TQ, TK, TV), two int32_t tensor
// references (TKV_block_idx, TKV_blocks), the DataType_ tensor reference Y,
// four optional DataType_ tensors (bias, lse, seqstart_q, seqstart_k), then
// the trailing scalar int/bool arguments.
template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparse_attention<ck_tile::bf16_t>(
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<int32_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
int, int, int, int, int, int, int, int, int, bool, bool, int, int);