mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Add bf16 support to jenga sparse attention
This commit is contained in:
@@ -6,17 +6,19 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
ck_tile::HostTensor<DataType>
|
||||
jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
ck_tile::HostTensor<DataType>& TK,
|
||||
ck_tile::HostTensor<DataType>& TV,
|
||||
ck_tile::HostTensor<DataType>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<DataType_>& TV,
|
||||
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
int batch,
|
||||
int nhead,
|
||||
@@ -31,8 +33,12 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k)
|
||||
{
|
||||
// Determine data type string based on template parameter
|
||||
std::string data_type = "fp16";
|
||||
// DataType is determined at compile time via template
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
@@ -208,7 +214,30 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
fmha_jenga_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
// Copy output back to host
|
||||
Y = o_buf.ToHost<DataType>();
|
||||
Y = o_buf.ToHost<DataType_>();
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
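// Explicit instantiations of jenga_sparse_attention for the two supported
// element types: ck_tile::half_t (fp16) and ck_tile::bf16_t (bf16).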
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
jenga_sparse_attention<ck_tile::half_t>(
|
||||
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&, ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
jenga_sparse_attention<ck_tile::bf16_t>(
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&, ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
@@ -7,18 +7,17 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
using DataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::HostTensor<DataType>
|
||||
jenga_sparse_attention(ck_tile::HostTensor<DataType>& TQ,
|
||||
ck_tile::HostTensor<DataType>& TK,
|
||||
ck_tile::HostTensor<DataType>& TV,
|
||||
ck_tile::HostTensor<DataType>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType>> seqstart_k,
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<DataType_>& TV,
|
||||
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
int batch,
|
||||
int nhead,
|
||||
|
||||
@@ -131,6 +131,15 @@ void reference_blocked_attention(
|
||||
}
|
||||
}
|
||||
|
||||
// Get error tolerance based on data type
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2; // Higher tolerance for bf16/fp16
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
@@ -148,13 +157,15 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
|
||||
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi")
|
||||
.insert("lse", "0", "0:not store lse, 1:store lse")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations");
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -163,29 +174,29 @@ auto create_args(int argc, char* argv[])
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using T = DataType; // Use DataType defined in header (half_t)
|
||||
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
bool store_lse = arg_parser.get_bool("lse");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
[[maybe_unused]] bool store_lse = arg_parser.get_bool("lse");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
[[maybe_unused]] int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
@@ -301,28 +312,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
// Warmup
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
jenga_sparse_attention(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
}
|
||||
|
||||
// Benchmark
|
||||
@@ -331,28 +342,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
jenga_sparse_attention(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
mode,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
@@ -389,8 +400,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
scale);
|
||||
|
||||
// Compare results
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
@@ -450,6 +460,22 @@ int main(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool test_result = run_test(arg_parser);
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
fmha_vsa_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
// Copy output back to host
|
||||
Y = o_buf.ToHost<DataType>();
|
||||
Y = o_buf.ToHost<DataType_>();
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user