diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 0e8eab8e6f..5ebc3fb94e 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -33,17 +33,18 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k); -ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, +template +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 9ac91660d2..a5c3f9f2b3 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -15,9 +15,6 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" -// Define DataType before including the header -using DataType = ck_tile::half_t; - #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" @@ -363,29 +360,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Warmup for(int i = 0; i < warmup; ++i) { - vsa_sparse_attention(q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } // Benchmark @@ -394,29 +391,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - vsa_sparse_attention(q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -526,10 +523,7 @@ int main(int argc, char* argv[]) } else if(prec == "bf16") { - std::cout << "Note: Using bf16 precision" << std::endl; - // For bf16, we would need to compile with DataType = ck_tile::bf16_t - // For now, run with the compiled DataType - test_result = run_test(arg_parser); + test_result = run_test(arg_parser); } else { diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index d75b5bae65..a824e9389f 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -6,18 +6,20 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/device_memory.hpp" +#include -ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t - ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +template +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& TKV_block_idx, + ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -32,8 +34,12 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k) { + // Determine data type string based on template parameter std::string data_type = "fp16"; - // DataType is determined at compile time via template + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -218,3 +224,26 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, return Y; } + +// Explicit template instantiations +template ck_tile::HostTensor +vsa_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); + +template ck_tile::HostTensor +vsa_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int);