add jenga support bf16

This commit is contained in:
Jiangyon
2025-12-16 06:57:42 +00:00
parent 997ec8f89c
commit 29d96a90f0
4 changed files with 147 additions and 93 deletions

View File

@@ -6,17 +6,19 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
ck_tile::HostTensor<DataType>
jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
ck_tile::HostTensor<DataType>& TK,
ck_tile::HostTensor<DataType>& TV,
ck_tile::HostTensor<DataType>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType>& Y,
std::optional<ck_tile::HostTensor<DataType>> bias,
std::optional<ck_tile::HostTensor<DataType>> lse,
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<DataType_>& TV,
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType_>& Y,
std::optional<ck_tile::HostTensor<DataType_>> bias,
std::optional<ck_tile::HostTensor<DataType_>> lse,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
int bias_type,
int batch,
int nhead,
@@ -31,8 +33,12 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
int max_seqlen_q,
int max_seqlen_k)
{
// Determine data type string based on template parameter
std::string data_type = "fp16";
// DataType is determined at compile time via template
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
if(max_seqlen_q == 0)
max_seqlen_q = seqlen_q;
@@ -208,7 +214,30 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
fmha_jenga_fwd(fmha_traits, args, stream_config);
// Copy output back to host
Y = o_buf.ToHost<DataType>();
Y = o_buf.ToHost<DataType_>();
return Y;
}
// Explicit template instantiations
template ck_tile::HostTensor<ck_tile::half_t>
jenga_sparse_attention<ck_tile::half_t>(
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
ck_tile::HostTensor<ck_tile::half_t>&,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
template ck_tile::HostTensor<ck_tile::bf16_t>
jenga_sparse_attention<ck_tile::bf16_t>(
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
int, int, int, int, int, int, int, int, int, bool, bool, int, int);

View File

@@ -7,18 +7,17 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
using DataType = ck_tile::half_t;
ck_tile::HostTensor<DataType>
jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
ck_tile::HostTensor<DataType>& TK,
ck_tile::HostTensor<DataType>& TV,
ck_tile::HostTensor<DataType>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType>& Y,
std::optional<ck_tile::HostTensor<DataType>> bias,
std::optional<ck_tile::HostTensor<DataType>> lse,
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<DataType_>& TV,
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType_>& Y,
std::optional<ck_tile::HostTensor<DataType_>> bias,
std::optional<ck_tile::HostTensor<DataType_>> lse,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
int bias_type,
int batch,
int nhead,

View File

@@ -131,6 +131,15 @@ void reference_blocked_attention(
}
}
// Get error tolerance based on data type
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2; // Higher tolerance for bf16/fp16
return ck_tile::make_tuple(rtol, atol);
}
// ============================================================================
// Command line argument parser
// ============================================================================
@@ -148,13 +157,15 @@ auto create_args(int argc, char* argv[])
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi")
.insert("lse", "0", "0:not store lse, 1:store lse")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations");
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -163,29 +174,29 @@ auto create_args(int argc, char* argv[])
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
using T = DataType; // Use DataType defined in header (half_t)
// Parse arguments
int do_validation = arg_parser.get_int("v");
int mode = arg_parser.get_int("mode");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t block_size = arg_parser.get_int("block_size");
float sparsity = arg_parser.get_float("sparsity");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
int bias_type = arg_parser.get_int("bias");
bool store_lse = arg_parser.get_bool("lse");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int do_validation = arg_parser.get_int("v");
int mode = arg_parser.get_int("mode");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t block_size = arg_parser.get_int("block_size");
float sparsity = arg_parser.get_float("sparsity");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
int bias_type = arg_parser.get_int("bias");
[[maybe_unused]] bool store_lse = arg_parser.get_bool("lse");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
[[maybe_unused]] int kname = arg_parser.get_int("kname");
// Handle default values
if(nhead_k < 0)
@@ -301,28 +312,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
// Warmup
for(int i = 0; i < warmup; ++i)
{
jenga_sparse_attention(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
}
// Benchmark
@@ -331,28 +342,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
for(int i = 0; i < repeat; ++i)
{
jenga_sparse_attention(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
bias_opt,
lse_opt,
seqstart_q_opt,
seqstart_k_opt,
bias_type,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
mode,
i_perm,
o_perm,
seqlen_q,
seqlen_k);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
@@ -389,8 +400,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
scale);
// Compare results
double rtol = 1e-2;
double atol = 4e-2;
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
@@ -450,6 +460,22 @@ int main(int argc, char* argv[])
return -1;
}
bool test_result = run_test(arg_parser);
std::string prec = arg_parser.get_str("prec");
bool test_result = false;
if(prec == "fp16")
{
test_result = run_test<ck_tile::half_t>(arg_parser);
}
else if(prec == "bf16")
{
test_result = run_test<ck_tile::bf16_t>(arg_parser);
}
else
{
std::cerr << "Unsupported precision: " << prec << std::endl;
return -1;
}
return test_result ? 0 : -1;
}

View File

@@ -220,7 +220,7 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
fmha_vsa_fwd(fmha_traits, args, stream_config);
// Copy output back to host
Y = o_buf.ToHost<DataType>();
Y = o_buf.ToHost<DataType_>();
return Y;
}