diff --git a/CHANGELOG.md b/CHANGELOG.md index addb4ce4f0..6a9b25b062 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. +* Added FP8 KV cache support for FMHA batch prefill. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index bd63c5d78f..95e8379769 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -24,8 +24,15 @@ from codegen.cpp_symbol_map import ( ) from codegen.utils import update_file - -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8": 8, + "fp8bf16": 8, + "fp8fp32": 8, + "bf8": 8, +} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} @@ -108,7 +115,7 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; + std::cout << ", {F_kname}" << std::flush; auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; @@ -494,6 +501,7 @@ class FmhaFwdKernel: @property def template(self) -> str: return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_kname=self.name, F_idx=self.F_idx, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], @@ -576,10 +584,14 @@ class FmhaFwdKernel: class KernelComponentFactory: @staticmethod def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": + if dtype in ["fp16", "bf16"]: return { 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip + elif dtype in ["fp8bf16"]: + return { + 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } # fmt: skip else: return None @@ -589,9 +601,9 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - qscale = "no" pipelines = [] if dtype in ["fp16", "bf16"]: + qscale = "no" for logits, mask, bias, lse, dropout in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -599,10 +611,16 @@ class KernelComponentFactory: ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip - # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip - # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + elif dtype in ["fp8bf16"]: + # no need lse/dropout kernels + for logits, qscale, mask, bias in itertools.product( + ["t", "f"], + ["pertensor"], + get_mask_map(mask_impl).keys(), + ["no"], + ): + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip else: assert False return pipelines @@ -612,7 +630,7 @@ class CustomFactory(KernelComponentFactory): @staticmethod def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in ["fp16", "bf16"]: if 128 in result.keys(): result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -695,15 +713,14 @@ def get_fwd_blobs( continue # Aiter(mha_batch_prefill) integration elif receipt == 200: - cond = dtype in ["fp16", "bf16"] + cond = dtype in ["fp16", "bf16", "fp8bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" if not cond: continue # aiter::mha_batch_prefill C++ api integration elif receipt == 600: - cond = dtype in ["fp16", "bf16"] + cond = dtype in ["fp16", "bf16", "fp8bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_qscale == "no" diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 6b3464d226..dd65c0298b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1017,7 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( - ["f"], + ["t", "f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"], diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 6a1c620577..ba55d6d722 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -500,6 +500,9 @@ struct fmha_batch_prefill_args const void* k_ptr; const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer + const void* q_descale_ptr; + const void* k_descale_ptr; + const void* v_descale_ptr; void* rand_val_ptr; void* lse_ptr; void* o_ptr; @@ -1118,6 +1121,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, @@ -1166,6 +1172,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 10b5d0120e..73b6a329d1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -36,6 +36,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; using RandValOutputDataType = ck_tile::remove_cvref_t; @@ -61,52 +62,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + - (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -211,6 +166,13 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_lse = 0; }; + struct FmhaFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + struct FmhaFwdDropoutSeedOffset { template @@ -274,8 +236,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -292,8 +257,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; ck_tile::index_t batch_stride_k; @@ -315,6 +283,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -396,6 +367,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for dropout {}, // placeholder for logits_soft_cap batch_stride_q, @@ -428,6 +400,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -463,6 +441,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -539,6 +520,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for qscale {}, // placeholder for dropout {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), @@ -568,6 +550,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1055,37 +1043,96 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel AttentionVariant variant; const auto variant_params = [&] { + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, - dropout); + auto o_acc_tile = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + // TODO - move global load of descale to pipeline + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::scales{scale_o}; + }(); + + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } }(); // O DRAM and O DRAM window diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9160e79af6..4dd99a6ea9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1499,14 +1499,28 @@ struct FmhaFwdKernel AttentionVariant variant; const auto variant_params = [&] { + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + if constexpr(kHasLogitsSoftCap) { return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; } else { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + return ck_tile::StandardAttentionParams{mask, scale_s}; } }(); @@ -1516,11 +1530,8 @@ struct FmhaFwdKernel if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline - float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); - float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); - float scale_s = kargs.scale_s * q_descale * k_descale; float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); float scale_o = v_descale / scale_p; @@ -1548,7 +1559,7 @@ struct FmhaFwdKernel o_acc_element_func, // o_acc_element_func mask, position_encoding, - scale_s, + variant_params.sm_scale, variant, variant_params, block_indices, @@ -1565,7 +1576,7 @@ struct FmhaFwdKernel lse_dram_window, mask, position_encoding, - kargs.scale_s, + variant_params.sm_scale, variant, variant_params, block_indices,