mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
[CK_TILE] naive attn (#1708)
* add reference attention fwd
* refactor addresser
* update
* paged, and i8 reflect-quant
* lets call it forward-quant
* fix error in decode variation
* update naive-attn
* fix page table
* fix build err
[ROCm/composable_kernel commit: 77a38e0211]
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ref/naive_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
#include "utils.hpp"
|
||||
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
@@ -447,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
bool s_randval = false;
|
||||
if(p_drop > 0.0f && do_validation)
|
||||
if(p_drop > 0.0f && do_validation != 0)
|
||||
{
|
||||
s_randval = true;
|
||||
}
|
||||
@@ -1121,11 +1122,61 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
if(!do_validation)
|
||||
if(do_validation == 0)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
if(do_validation == 2)
|
||||
{
|
||||
// NOTE: use gpu to do validation
|
||||
ck_tile::naive_attention_fwd_traits naive_t;
|
||||
naive_t.q_type = data_type;
|
||||
naive_t.k_type = data_type;
|
||||
naive_t.v_type = data_type;
|
||||
naive_t.o_type = data_type;
|
||||
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
|
||||
naive_t.variation = 0; // TODO?
|
||||
|
||||
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::naive_attention_fwd_args naive_a;
|
||||
naive_a.q_ptr = q_buf.GetDeviceBuffer();
|
||||
naive_a.k_ptr = k_buf.GetDeviceBuffer();
|
||||
naive_a.v_ptr = v_buf.GetDeviceBuffer();
|
||||
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
|
||||
naive_a.scale_s = scale_s;
|
||||
naive_a.context_len_ptr = nullptr; // used when seqlen_kv comes from a pointer
|
||||
naive_a.page_table_ptr =
|
||||
nullptr; // [batch, num_blocks]; seqlen_kv is spread across different blocks (paged attn)
|
||||
naive_a.hdim = hdim_q;
|
||||
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
|
||||
naive_a.batch_q = batch;
|
||||
naive_a.batch_kv = batch;
|
||||
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
|
||||
naive_a.seqlen_q = seqlen_qs[0];
|
||||
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
|
||||
naive_a.nhead_q = nhead;
|
||||
naive_a.nhead_kv = nhead_k;
|
||||
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
|
||||
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
|
||||
|
||||
ck_tile::stream_config naive_s{};
|
||||
|
||||
naive_attention_fwd(naive_t, naive_a, naive_s);
|
||||
|
||||
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
|
||||
return pass_;
|
||||
}
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
|
||||
Reference in New Issue
Block a user