From 60e814a3ba8ee804c79268c0bca0248af6404123 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 3 Jan 2025 18:43:07 +0800 Subject: [PATCH] [CK_TILE]naive attn support FP8 KVCache quant (#1747) * quant * fix bug * simple smoothquant after softmax * update kv-quant * update stride * fix fp8-pertoken-kvcache * update int8/fp8 quant support --------- Co-authored-by: so Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 6df5fe2ad8fb6ff054a3e75250ccef7c878c3455] --- example/ck_tile/01_fmha/fmha_fwd.cpp | 19 +- include/ck_tile/ref/naive_attention.hpp | 422 ++++++++++++++++-------- 2 files changed, 301 insertions(+), 140 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 08d263da91..b3855e59df 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser) { // NOTE: use gpu to do validation ck_tile::naive_attention_fwd_traits naive_t; - naive_t.q_type = data_type; - naive_t.k_type = data_type; - naive_t.v_type = data_type; - naive_t.o_type = data_type; - naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; - naive_t.variation = 0; // TODO? + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? + naive_t.quant_algo = 0; ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 09ded761eb..98ceab6992 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -13,13 +13,18 @@ namespace ck_tile { enum class naive_attention_layout_enum { - BSHD, // [batch, seqlen, nhead, hdim] - BHSD, // [batch, nhead, seqlen, hdim] - BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed - PHSD, // [pages, nhead, page_size, hdim] + DEFAULT, // maybe this tensor is not used, set some irrelevant value + BSHD, // [batch, seqlen, nhead, hdim] + BHSD, // [batch, nhead, seqlen, hdim] + BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed + PHSD, // [pages, nhead, page_size, hdim] // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen + + // scale layout used for dynamic dequant + SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant + SCALE_SH, // [tokens, nhead] }; // will used to specialize kernel variation @@ -30,6 +35,15 @@ enum class naive_attention_variation_enum DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache }; +enum class naive_attention_quant_algo +{ + NO = 0, + KV_8BIT_PERHEAD = 1, + // FP8/INT8 quant for KVCache, per-token quant + // [num_tokens, nhead, hdim] -> [nhead, num_tokens] + KV_8BIT_PERTOKEN = 2, +}; + // TODO: for simplicity, this will be used as host/device arg struct naive_attention_fwd_args { @@ -40,7 +54,8 @@ struct naive_attention_fwd_args void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a // number, not cumsum) void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn) - void* kvscale_ptr; // [nhead, 2(kv), hdim] used for kvcache dequant + void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant + void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant float scale_s; int hdim; int hdim_v; // could be cross-attn, where V and Q/K hdim are different @@ -54,6 +69,7 @@ struct naive_attention_fwd_args int nhead_ratio_kv; // nhead_q / nhead_kv int page_size; // if paged, the seqlen-kv per each block int max_pages_per_seq; + int max_kv_tokens; // used as stride to access kv scale ptr }; // this is trait for host API @@ -67,14 +83,16 @@ struct naive_attention_fwd_traits std::string k_layout; std::string v_layout; std::string o_layout; - int variation; // sync with naive_attention_variation_enum + int variation; // sync with naive_attention_variation_enum + int quant_algo; // sync with naive_attention_quant_algo }; // this is trait for kernel template -template +template struct naive_attention_fwd_kernel_traits { static constexpr naive_attention_variation_enum variation = variation_; + static constexpr naive_attention_quant_algo quant_algo = quant_algo_; }; // for simplicity, please do not use const-reference type for the template type @@ -83,28 +101,39 @@ template struct naive_attention_fwd_kernel { static constexpr bool is_kvcache_i8 = - std::is_same_v && std::is_same_v && sizeof(QType) != 1; + std::is_same_v && std::is_same_v; + static constexpr bool is_kvcache_fp8 = + std::is_same_v && std::is_same_v; - // kvcache-i8 will have per head scale, we apply this scale to Q/P matrix instead of original - // K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32 - static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8; + static constexpr int v_per_token_quant_group_size = 64; // TODO: hardcode - using KVScaleType = float; - using SoftmaxType = float; - using PType = VType; // src A of gemm2, same type as V + using SoftmaxType = float; // always using float to do softmax compute + using QuantComputeType = float; // used for quant/dequant scale compute + using QCompute = KType; // src A of gemm1, same type as K + using PType = VType; // src A of gemm2, same type as V + using OAccType = float; // always float, in case int8 FA using p_vec_type = ext_vector_t; static constexpr int p_vec_elem = vector_traits::vector_size; + // clang-format off + template struct scale_max { static constexpr float value = 1; /* dummy code */ }; + template <> struct scale_max { static constexpr float value = 127.0; }; + template <> struct scale_max { static constexpr float value = 240.0; }; + // clang-format on + __host__ __device__ naive_attention_fwd_kernel() {} template @@ -198,24 +227,31 @@ struct naive_attention_fwd_kernel __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {} }; - template + template struct kvscale_addresser { - int h, d; // nhead, hdim + int s, h, d; // seqlen(tokens), nhead, hdim T* base_ptr; - __device__ kvscale_addresser(int h_, int d_, void* p_) - : h(h_), d(d_), base_ptr(reinterpret_cast(p_)) + __device__ kvscale_addresser(int s_, int h_, int d_, void* p_) + : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast(p_)) { } - __device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/) + __device__ int get_offset(int i_s, int i_h, int i_d) { + if constexpr(Layout == naive_attention_layout_enum::SCALE_HS) + { + // [nhead, tokens] + (void)i_d; + return i_h * s + i_s; + } + else if constexpr(Layout == naive_attention_layout_enum::DEFAULT) + { + return 0; + } // [h, 2, d] - return i_h * 2 * d + i_kv * d + i_d; - } - __device__ T load(int i_h, int i_d, int i_kv) - { - return base_ptr[get_offset(i_h, i_d, i_kv)]; + // return i_h * 2 * d + i_kv * d + i_d; } + __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; } }; __device__ __host__ static constexpr int get_block_size() { return 256; } @@ -282,12 +318,13 @@ struct naive_attention_fwd_kernel __device__ void operator()(naive_attention_fwd_args args) { constexpr int wg_size = get_block_size(); - __shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough - int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v - int i_sq = blockIdx.y; // index of seqlen_q - int i_batch = blockIdx.z; // index of batch_q * nhead_q - int i_bq = i_batch / args.nhead_q; // index of batch_q - int i_hq = i_batch % args.nhead_q; // index of nhead_q + __shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough + char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough + int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v + int i_sq = blockIdx.y; // index of seqlen_q + int i_batch = blockIdx.z; // index of batch_q * nhead_q + int i_bq = i_batch / args.nhead_q; // index of batch_q + int i_hq = i_batch % args.nhead_q; // index of nhead_q int i_bk = i_bq / args.batch_ratio_kv; int i_hk = i_hq / args.nhead_ratio_kv; @@ -360,9 +397,10 @@ struct naive_attention_fwd_kernel auto f_max = [](auto x_, auto y_) { return max(x_, y_); }; auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; auto f_absmax_f32 = [](float v_0_, float v_1_) { - float rtn; - asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_)); - return rtn; + // float rtn; + // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_)); + // return rtn; + return max(abs(v_0_), abs(v_1_)); }; int seqlen_kv = [&]() { @@ -378,45 +416,82 @@ struct naive_attention_fwd_kernel SoftmaxType row_max = -numeric::infinity(); SoftmaxType l{0}; - AccType o_acc = {0}; + // AccType o_acc = {0}; + OAccType o_acc = {0}; - int sk_loops = (seqlen_kv + wg_size - 1) / wg_size; - float qf_scale = .0f; - kvscale_addresser kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr}; + int sk_loops = (seqlen_kv + wg_size - 1) / wg_size; + QuantComputeType q_dequant_scale = .0f; + kvscale_addresser kscale_addr{ + args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr}; + kvscale_addresser vscale_addr{ + args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr}; - if constexpr(is_kvcache_i8_forward_quant) + if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD) { // AccType is i32 now, seqlen_q = 1, hdim up to 256 - float q = 0; - float k_s = 0; + AccType q = 0; + AccType k_s = 0; if(static_cast(threadIdx.x) < args.hdim) { - q = type_convert(q_addr.load(0, threadIdx.x)); - k_s = type_convert(kvscale_addr.load(i_hk, threadIdx.x, 0)); + q = type_convert(q_addr.load(0, threadIdx.x)); + k_s = type_convert(kscale_addr.load(i_hk, threadIdx.x, 0)); } // 1) we apply the k scale to q - float q_forwarded = q * k_s; + AccType q_forwarded = q * k_s; // 2) apply smooth-quant // find absmax - float qf_max = wave_reduce(q_forwarded, f_absmax_f32); - qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast(smem)); + AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32); + qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast(smem)); // per-token scale - qf_scale = qf_max / 127.0; + q_dequant_scale = type_convert(qf_max) / scale_max::value; // devide by scale - q = q / qf_scale; + q = q / q_dequant_scale; // fp32->i8 - int8_t quantized_q = static_cast(q); + QCompute quantized_q = static_cast(q); __syncthreads(); - reinterpret_cast(smem)[threadIdx.x] = quantized_q; + reinterpret_cast(smem)[threadIdx.x] = quantized_q; __syncthreads(); // after above process, we have 2 data // 1) int8 q data stored in smem(no need to reload) - // 2) per-token scale qf_scale, to be mul after 1st gemm + // 2) per-token scale q_dequant_scale, to be mul after 1st gemm + } + else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN) + { + if(std::is_same_v || std::is_same_v) + { + // dyanmic quant q here + float q = 0; + if(static_cast(threadIdx.x) < args.hdim) + { + q = type_convert(q_addr.load(i_sq, threadIdx.x)); + } + + // apply smooth-quant + // find absmax + float q_max = wave_reduce(q, f_absmax_f32); + q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast(smem)); + + // per-token scale + q_dequant_scale = + type_convert(q_max) / scale_max::value; + + // devide by scale + q = q / q_dequant_scale; + + QCompute quantized_q = type_convert(q); + __syncthreads(); + reinterpret_cast(smem_quant_q)[threadIdx.x] = quantized_q; + __syncthreads(); + + // after above process, we have 2 data + // 1) fp8 q data stored in smem(no need to reload from global) + // 2) per-token scale q_dequant_scale, to be mul after 1st gemm + } } for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++) @@ -429,33 +504,41 @@ struct naive_attention_fwd_kernel AccType s_acc{0}; // clear for every loop for(auto i_dq = 0; i_dq < args.hdim; i_dq++) { - if constexpr(is_kvcache_i8_forward_quant) - { - int8_t q = reinterpret_cast(smem)[i_dq]; - auto k = k_addr.load(i_sk, i_dq); + auto q = [&]() { + if constexpr(Traits::quant_algo == + naive_attention_quant_algo::KV_8BIT_PERHEAD || + Traits::quant_algo == + naive_attention_quant_algo::KV_8BIT_PERTOKEN) + { + return reinterpret_cast(smem_quant_q)[i_dq]; + } + else + return q_addr.load(i_sq, i_dq); // q will have duplicate load + }(); + auto k = [&]() { return k_addr.load(i_sk, i_dq); }(); - s_acc += type_convert(q) * type_convert(k); - } - else - { - auto q = q_addr.load(i_sq, i_dq); // q will have duplicate load - auto k = k_addr.load(i_sk, i_dq); - - s_acc += type_convert(q) * type_convert(k); - } + s_acc += type_convert(q) * type_convert(k); } // scale s_softmax = type_convert(s_acc); s_softmax *= type_convert(args.scale_s * ck_tile::log2e_v); - if constexpr(is_kvcache_i8_forward_quant) + if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD) { - s_softmax *= qf_scale; // post scale the per-token factor + s_softmax *= q_dequant_scale; // post scale the per-token factor + } + else if constexpr(Traits::quant_algo == + naive_attention_quant_algo::KV_8BIT_PERTOKEN) + { + SoftmaxType k_per_token_scale = + type_convert(kscale_addr.load(i_sk, i_hk, 0)); + s_softmax *= q_dequant_scale; + s_softmax *= k_per_token_scale; } } // s->p - float pf_scale = 0.; // used for i8 quant + QuantComputeType p_dequant_scale = 1.; { // softmax, find max SoftmaxType old_max = row_max; @@ -473,41 +556,69 @@ struct naive_attention_fwd_kernel // l, pre-scall o_acc SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max); l = tmp * l + row_sum; - o_acc = type_convert(type_convert(o_acc) * tmp); + o_acc = type_convert(type_convert(o_acc) * tmp); // prepare the p_compute into smem, to let every thread read same p_compute and do // 2nd gemm - if constexpr(is_kvcache_i8_forward_quant) + if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD) { - float v_s = 0; + QuantComputeType v_s = 0; if(static_cast(threadIdx.x) < args.hdim_v) { - v_s = type_convert(kvscale_addr.load(i_hk, threadIdx.x, 1)); + v_s = + type_convert(vscale_addr.load(i_hk, threadIdx.x, 1)); } // 1) we apply the v scale to p - float p_forwarded = p_compute * v_s; + QuantComputeType p_forwarded = p_compute * v_s; // 2) apply smooth-quant // find absmax - float pf_max = wave_reduce(p_forwarded, f_absmax_f32); - pf_max = - cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast(smem)); + QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32); + pf_max = cross_wave_reduce( + pf_max, f_absmax_f32, reinterpret_cast(smem)); // per-token scale - pf_scale = pf_max / 127.0; + p_dequant_scale = pf_max / scale_max::value; // 127.0; // devide by scale - p_compute = p_compute / pf_scale; + p_compute = p_compute / p_dequant_scale; // fp32->i8 - int8_t quantized_p = static_cast(p_compute); + PType quantized_p = static_cast(p_compute); __syncthreads(); - reinterpret_cast(smem)[threadIdx.x] = quantized_p; + reinterpret_cast(smem)[threadIdx.x] = quantized_p; __syncthreads(); // after above process, we have 2 data // 1) int8 p data stored in smem(no need to reload) - // 2) per-token scale pf_scale, to be mul after 2nd gemm + // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm + } + else if constexpr(Traits::quant_algo == + naive_attention_quant_algo::KV_8BIT_PERTOKEN) + { + // forward apply the v scale to p_compute, this is compute friendly + auto v_scale = type_convert(vscale_addr.load(i_sk, i_hk, 0)); + p_compute *= v_scale; + // smooth-quant + // find absmax + QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32); + p_max = cross_wave_reduce( + p_max, f_absmax_f32, reinterpret_cast(smem)); + + // per-token scale + p_dequant_scale = p_max / scale_max::value; // 240.0; + + // devide by scale + p_compute = p_compute / p_dequant_scale; + + // fp32->i8 + PType quantized_p = type_convert(p_compute); + __syncthreads(); + reinterpret_cast(smem)[threadIdx.x] = quantized_p; + __syncthreads(); + // after above process, we have 2 data + // 1) fp8_t p data stored in smem(no need to reload) + // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm } else { @@ -531,29 +642,45 @@ struct naive_attention_fwd_kernel int sv_offset = i_loop2 * p_vec_elem + i_j; int i_sv = sk_start + sv_offset; - VType v = 0.f; + VType v = 0; if(i_dv < args.hdim_v && i_sv < seqlen_kv) { v = v_addr.load(i_sv, i_dv); } - o_acc_local += type_convert(p_vec[i_j]) * type_convert(v); + AccType v_compute = [&]() { return type_convert(v); }(); + + o_acc_local += type_convert(p_vec[i_j]) * v_compute; } } - if constexpr(is_kvcache_i8_forward_quant) - { - // apply pr scale to local acc - o_acc_local = - type_convert(type_convert(o_acc_local) * pf_scale); - } - o_acc += o_acc_local; + + OAccType post_scale_o_acc_local = [&]() { + if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD) + { + // apply pr scale to local acc + return type_convert(type_convert(o_acc_local) * + p_dequant_scale); + } + else if constexpr(Traits::quant_algo == + naive_attention_quant_algo::KV_8BIT_PERTOKEN) + { + // apply pr scale to local acc + return type_convert(type_convert(o_acc_local) * + p_dequant_scale); + } + else + { + return type_convert(o_acc_local); + } + }(); + o_acc += post_scale_o_acc_local; } } // post scale o_acc { SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking - o_acc = type_convert(type_convert(o_acc) * tmp); + o_acc = type_convert(type_convert(o_acc) * tmp); } // store O @@ -564,18 +691,21 @@ struct naive_attention_fwd_kernel #define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \ { \ - using ktraits_ = \ - naive_attention_fwd_kernel_traits( \ - variation_)>; \ + using ktraits_ = naive_attention_fwd_kernel_traits< \ + static_cast(variation_), \ + static_cast(quant_algo_)>; \ using k_ = naive_attention_fwd_kernel; \ dim3 grids = k_::get_grid_size(a); \ r = ck_tile::launch_kernel(s, \ @@ -586,31 +716,37 @@ struct naive_attention_fwd_kernel if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \ t.o_layout == "bshd") \ { \ - constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \ - constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \ - constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \ - constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \ - constexpr int variation_ = 0; \ + constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \ + constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \ + constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \ + constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \ + constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \ + constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \ + constexpr int variation_ = 0; \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ } \ else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \ t.v_layout == "bhsd" && t.o_layout == "bhsd") \ { \ - constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr int variation_ = 0; \ + constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \ + constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \ + constexpr int variation_ = 0; \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ } \ else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \ t.v_layout == "phds" && t.o_layout == "bhsd") \ { \ - constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \ - constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \ - constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ - constexpr int variation_ = 2; \ + constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \ + constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \ + constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \ + constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \ + constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \ + constexpr int variation_ = 2; \ CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \ } @@ -621,40 +757,64 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t, { float r = -1; // TODO: do not explicitly create too much instance! - if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16") + if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" && + t.quant_algo == 0) { - using q_type_ = fp16_t; - using k_type_ = fp16_t; - using v_type_ = fp16_t; - using o_type_ = fp16_t; - using acc_type_ = float; + using q_type_ = fp16_t; + using k_type_ = fp16_t; + using v_type_ = fp16_t; + using o_type_ = fp16_t; + using acc_type_ = float; + using kvscale_type_ = float; + constexpr int quant_algo_ = 0; CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); } - else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16") + else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" && + t.quant_algo == 0) { - using q_type_ = bf16_t; - using k_type_ = bf16_t; - using v_type_ = bf16_t; - using o_type_ = bf16_t; - using acc_type_ = float; + using q_type_ = bf16_t; + using k_type_ = bf16_t; + using v_type_ = bf16_t; + using o_type_ = bf16_t; + using acc_type_ = float; + using kvscale_type_ = float; + constexpr int quant_algo_ = 0; CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); } - else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16") + else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" && + t.quant_algo == 2) { - using q_type_ = bf16_t; - using k_type_ = int8_t; - using v_type_ = int8_t; - using o_type_ = bf16_t; - using acc_type_ = int32_t; // NOTE! + using q_type_ = bf16_t; + using k_type_ = fp8_t; + using v_type_ = fp8_t; + using o_type_ = bf16_t; + using acc_type_ = float; // NOTE! + using kvscale_type_ = float; + constexpr int quant_algo_ = 2; CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); } - else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16") + else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" && + t.quant_algo == 2) { - using q_type_ = fp16_t; - using k_type_ = int8_t; - using v_type_ = int8_t; - using o_type_ = fp16_t; - using acc_type_ = int32_t; // NOTE! + using q_type_ = fp16_t; + using k_type_ = fp8_t; + using v_type_ = fp8_t; + using o_type_ = fp16_t; + using acc_type_ = float; // NOTE! + using kvscale_type_ = float; + constexpr int quant_algo_ = 2; + CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); + } + else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" && + t.quant_algo == 2) + { + using q_type_ = bf16_t; + using k_type_ = int8_t; + using v_type_ = int8_t; + using o_type_ = bf16_t; + using acc_type_ = int32_t; // NOTE! + using kvscale_type_ = float; + constexpr int quant_algo_ = 2; CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_(); } return r;