From 2312eef6c36d2811b1f57c85c8ae4c58a595be9e Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Wed, 11 Mar 2026 10:00:52 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc) [CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- CHANGELOG.md | 1 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 4 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 46 +- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 22 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 67 +++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 505 +++++++++++++----- example/ck_tile/01_fmha/quant.hpp | 7 + .../core/arch/amd_buffer_addressing.hpp | 3 +- .../arch/amd_buffer_addressing_builtins.hpp | 3 +- include/ck_tile/host.hpp | 1 + .../reference_batched_mx_descale.hpp | 61 +++ include/ck_tile/ops/fmha.hpp | 1 + .../block_attention_quant_scale_enum.hpp | 11 + .../ck_tile/ops/fmha/block/cast_tile_mx.hpp | 186 +++++++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 299 ++++++++++- .../pipeline/block_fmha_pipeline_problem.hpp | 9 + .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 265 +++++++-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 13 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 338 +++++++++--- include/ck_tile/ops/gemm.hpp | 2 + .../block_gemm_mx_areg_bsmem_creg_v1.hpp | 374 +++++++++++++ ...mm_mx_areg_bsmem_creg_v1_custom_policy.hpp | 36 ++ include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 36 ++ .../gemm/warp/warp_gemm_attribute_mfma.hpp | 26 + .../warp/warp_gemm_attribute_mfma_impl.hpp | 121 +++-- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 10 +- test/CMakeLists.txt | 4 +- test/ck_tile/fmha/CMakeLists.txt | 10 +- test/ck_tile/fmha/test_fmha_fwd.cpp | 62 ++- 29 files changed, 2167 insertions(+), 356 deletions(-) create mode 100644 include/ck_tile/host/reference/reference_batched_mx_descale.hpp create mode 100644 include/ck_tile/ops/fmha/block/cast_tile_mx.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index e1feebcb3e..370e9e4243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. * Added FP8 block scale quantization for FMHA forward kernel. * Added gfx11 support for FMHA. +* Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only). ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 995fc8c965..e9ae11fb5f 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -9,6 +9,8 @@ FWD_DTYPE_MAP = { "fp8fp16": "FmhaFwdFp8Fp16", "fp8bf16": "FmhaFwdFp8Bf16", "fp8fp32": "FmhaFwdFp8Fp32", + "mxfp8": "FmhaFwdMxFp8", + "mxfp4": "FmhaFwdMxFp4", } BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"} @@ -79,6 +81,7 @@ QSCALE_MAP = { "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE", + "mx": "ck_tile::BlockAttentionQuantScaleEnum::MX", } QSCALE_CHECK_MAP = { @@ -86,6 +89,7 @@ QSCALE_CHECK_MAP = { "pertensor": "quant_scale_enum::pertensor", "blockscale": "quant_scale_enum::blockscale", "kv_blockscale": "quant_scale_enum::kv_blockscale", + "mx": "quant_scale_enum::mx", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 627352e226..18e0022cf5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -38,6 +38,8 @@ DTYPE_BITS = { "fp8bf16": 8, "fp8fp32": 8, "bf8": 8, + "mxfp8": 8, + "mxfp4": 4, } K0_MAX_SUBMAX_MAP = { @@ -836,7 +838,8 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): def check_hdim_tile( problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> bool: - if problem_ctx.dtype != "fp32": + # FIX: too confusing that it has to know about mx types + if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"): # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( ( @@ -966,8 +969,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - else: - raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future @@ -1035,9 +1036,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip - elif dtype in ["fp8", "fp8fp16", "bf8"]: - # TODO - pass return pipelines @@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950( ): arch = ArchTrait("gfx950") + _DT_MXFP8 = ("mxfp8",) + _DT_MXFP4 = ("mxfp4",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return ( + KernelComponentFactoryGfx9.supported_dtypes() + + cls._DT_MXFP8 + + cls._DT_MXFP4 + ) + @classmethod def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) @@ -1054,6 +1063,18 @@ class KernelComponentFactoryGfx950( if (128, 128) in result.keys(): result[(128, 128)].append( FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + elif dtype in cls._DT_MXFP8: + return { + # bm0, bn0, bk0, bn1, bk1, + (128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)], + } # fmt: skip + elif dtype in cls._DT_MXFP4: + return { + # bm0, bn0, bk0, bn1, bk1, + (128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)], + (256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)], + } # fmt: skip return result @classmethod @@ -1091,6 +1112,19 @@ class KernelComponentFactoryGfx950( pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: + # no need dropout kernels + lse = "t" + dropout = "f" + for logits, qscale, mask, bias, sink in itertools.product( + ["f"], + ["mx"], + get_mask_map(mask_impl).keys(), + ["no"], + ["f", "t"], + ): + pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index f5ad6b2bc5..122d232a1c 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -48,8 +48,12 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") .insert("qscale", "n", - "n or 0, no scale\n" - "pt or 1, per-tensor scale\n") + "quant scale:\n" + " n or 0, no scale\n" + " pt or 1, per-tensor scale\n" + " bs or 2, block scale\n" + " kvbs or 3, Q per-tensor, K/V per-page block scale\n" + " mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)") .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("iperm", "1", @@ -61,7 +65,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -231,6 +235,10 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } + else if(data_type == "fp8") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } else if(data_type == "fp8bf16") { return run(arg_parser) == fwd_result::success ? 0 : -2; @@ -239,6 +247,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } + else if(data_type == "mxfp8") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "mxfp4") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3123e2bd59..4adb159b31 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -50,6 +50,14 @@ struct FmhaFwdFp8Fp32 { }; +struct FmhaFwdMxFp8 +{ +}; + +struct FmhaFwdMxFp4 +{ +}; + template struct FmhaFwdTypeConfig; @@ -165,6 +173,54 @@ struct FmhaFwdTypeConfig using ODataType = float; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; + + using QScaleDataType = ck_tile::e8m0_t; + using KScaleDataType = ck_tile::e8m0_t; + using VScaleDataType = ck_tile::e8m0_t; + using PScaleDataType = ck_tile::e8m0_t; + + static constexpr ck_tile::index_t kQKScaleGranularity = 32; + static constexpr ck_tile::index_t kVScaleGranularity = 32; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::pk_fp4_t; + using KDataType = ck_tile::pk_fp4_t; + using VDataType = ck_tile::pk_fp4_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::pk_fp4_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; + + using QScaleDataType = ck_tile::e8m0_t; + using KScaleDataType = ck_tile::e8m0_t; + using VScaleDataType = ck_tile::e8m0_t; + using PScaleDataType = ck_tile::e8m0_t; + + static constexpr ck_tile::index_t kQKScaleGranularity = 32; + static constexpr ck_tile::index_t kVScaleGranularity = 32; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; @@ -232,6 +288,7 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* block_scale_seqstart_q_ptr; const void* block_scale_seqstart_k_ptr; + const void* seqstart_v_scale_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -252,6 +309,9 @@ struct fmha_fwd_args ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_randval; ck_tile::index_t stride_o; + ck_tile::index_t stride_q_descale; + ck_tile::index_t stride_k_descale; + ck_tile::index_t stride_v_descale; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; @@ -635,6 +695,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqlen_k_ptr, args.block_scale_seqstart_q_ptr, args.block_scale_seqstart_k_ptr, + args.seqstart_v_scale_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -647,6 +708,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_bias, args.stride_randval, args.stride_o, + args.stride_q_descale, + args.stride_k_descale, + args.stride_v_descale, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, @@ -697,6 +761,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_bias, args.stride_randval, args.stride_o, + args.stride_q_descale, + args.stride_k_descale, + args.stride_v_descale, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 1227724d40..17d53a4e6d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -84,6 +84,22 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-1; + double atol = 2.6e-1; + return ck_tile::make_tuple(rtol, atol); +} + int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split @@ -171,6 +187,28 @@ void copy_attention_scores_with_sink(const ck_tile::HostTensor +struct ScalesConfig +{ + using QScaleDataType = float; + using KScaleDataType = float; + using VScaleDataType = float; + + static constexpr ck_tile::index_t kQKScaleGranularity = 1; + static constexpr ck_tile::index_t kVScaleGranularity = 1; +}; + +template +struct ScalesConfig +{ + using QScaleDataType = typename TypeConfig::QScaleDataType; + using KScaleDataType = typename TypeConfig::KScaleDataType; + using VScaleDataType = typename TypeConfig::VScaleDataType; + + static constexpr ck_tile::index_t kQKScaleGranularity = TypeConfig::kQKScaleGranularity; + static constexpr ck_tile::index_t kVScaleGranularity = TypeConfig::kVScaleGranularity; +}; + template fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t batch, @@ -210,6 +248,31 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + using TypeConfig = FmhaFwdTypeConfig; + + constexpr bool is_mx = ck_tile::is_any_of::value; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = std::conditional_t; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + using QScaleDataType = typename ScalesConfig::QScaleDataType; + using KScaleDataType = typename ScalesConfig::KScaleDataType; + using VScaleDataType = typename ScalesConfig::VScaleDataType; + + constexpr ck_tile::index_t kQKScaleGranularity = + ScalesConfig::kQKScaleGranularity; + constexpr ck_tile::index_t kVScaleGranularity = + ScalesConfig::kVScaleGranularity; + // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the // compute block size constexpr ck_tile::index_t block_scale_size_q_ = 128; @@ -230,6 +293,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8bf16"; else if constexpr(std::is_same_v) return "fp8fp32"; + else if constexpr(std::is_same_v) + return "mxfp8"; + else if constexpr(std::is_same_v) + return "mxfp4"; else static_assert(false); }(); @@ -242,6 +309,26 @@ fwd_result fmha_fwd_run(mode_enum mode, return fwd_result::invalid_args; } + if(hdim_q % ck_tile::numeric_traits::PackedSize != 0) + { + std::cerr << "hdim_q is made even for fp4 Q data type" << std::endl; + hdim_q = + ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits::PackedSize); + } + if(hdim_q % ck_tile::numeric_traits::PackedSize != 0) + { + std::cerr << "hdim_q is made even for fp4 K data type" << std::endl; + hdim_q = + ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits::PackedSize); + } + if(is_mx && !seqlen_kpads.empty() && seqlen_kpads[0] > 0) + { + std::cerr + << "seqlen_kpads is not supported with MX types. ignoring the 'seqlen_kpads' option" + << std::endl; + seqlen_kpads = {-1}; + } + std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}()); auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; @@ -371,6 +458,20 @@ fwd_result fmha_fwd_run(mode_enum mode, /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, need_append_kvcache, random_engine); + + if(ck_tile::numeric_traits::PackedSize != 0) + { + // Ensure that all seqlens are even if V has packed data type + for(auto& s : seqlen_ks) + { + s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits::PackedSize); + } + for(auto& s : kv_eff_lens_per_batch) + { + s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits::PackedSize); + } + } + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb]) @@ -410,6 +511,22 @@ fwd_result fmha_fwd_run(mode_enum mode, quant_scale_info qscale = quant_scale_info::decode(qscale_str); + if(is_mx && qscale.type != quant_scale_enum::mx) + { + std::cerr << "The value of qscale_str must be 'mx' for MX data types" << std::endl; + return fwd_result::invalid_args; + } + else if(!is_mx && qscale.type == quant_scale_enum::mx) + { + std::cerr << "The value of qscale_str cannot be 'mx' for non-MX data types" << std::endl; + return fwd_result::invalid_args; + } + if(is_mx && is_v_rowmajor) + { + std::cerr << "The value of is_v_rowmajor must be 'false' for MX data types" << std::endl; + return fwd_result::invalid_args; + } + if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; @@ -458,28 +575,16 @@ fwd_result fmha_fwd_run(mode_enum mode, calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); } - using TypeConfig = FmhaFwdTypeConfig; - - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t i_block_scale_q = 0; - size_t i_block_scale_k = 0; + int32_t i_block_scale_q = 0; + int32_t i_block_scale_k = 0; + int32_t i_seqstart_v_scale = 0; std::vector block_scale_seqstart_q_host = {0}; std::vector block_scale_seqstart_k_host = {0}; + std::vector seqstart_v_scale_host = {0}; auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) @@ -496,18 +601,31 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } - i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); - i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); - block_scale_seqstart_q_host.push_back(i_block_scale_q); - block_scale_seqstart_k_host.push_back(i_block_scale_k); + if(qscale.type == quant_scale_enum::blockscale) + { + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += + ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); + } + else if(qscale.type == quant_scale_enum::mx) + { + i_seqstart_v_scale += + ck_tile::integer_divide_ceil(real_seqlen_k, kVScaleGranularity); + seqstart_v_scale_host.push_back(i_seqstart_v_scale); + } flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q / + ck_tile::numeric_traits::PackedSize + sizeof(ODataType) * real_seqlen_q * hdim_v); - num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q / + ck_tile::numeric_traits::PackedSize + + sizeof(VDataType) * hdim_v * real_seqlen_k / + ck_tile::numeric_traits::PackedSize); } } @@ -620,19 +738,30 @@ fwd_result fmha_fwd_run(mode_enum mode, hdim_v} : std::array{1, 1, 1, 1, 1}); - // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead, num_block_scale_q} - : std::array{1, 1, 1}); - ck_tile::HostTensor k_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); - ck_tile::HostTensor v_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); + const ck_tile::index_t hdim_q_scale = ck_tile::integer_divide_ceil(hdim_q, kQKScaleGranularity); + const ck_tile::index_t shape_seqlen_v_scale = seqstart_v_scale_host.back(); + + ck_tile::HostTensor q_descale_host({1}); + ck_tile::HostTensor k_descale_host({1}); + ck_tile::HostTensor v_descale_host({1}); + if constexpr(is_mx) + { + q_descale_host = ck_tile::HostTensor( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q_scale)); + k_descale_host = ck_tile::HostTensor( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q_scale)); + v_descale_host = ck_tile::HostTensor( + get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_v_scale)); + } + else if(qscale.type == quant_scale_enum::blockscale) + { + q_descale_host = ck_tile::HostTensor( + std::array{shape_batch, nhead, num_block_scale_q}); + k_descale_host = ck_tile::HostTensor( + std::array{shape_batch, nhead_k, num_block_scale_kv}); + v_descale_host = ck_tile::HostTensor( + std::array{shape_batch, nhead_k, num_block_scale_kv}); + } // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -704,10 +833,9 @@ fwd_result fmha_fwd_run(mode_enum mode, } else if(init_method == "3") { - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float bias_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, next_seed()}(q_host); ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(k_host); @@ -716,8 +844,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}(v_host); ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}( vnew_host); - ck_tile::FillUniformDistribution{ - -bias_dtype_max, bias_dtype_max, next_seed()}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(bias_host); } if(bias.type == bias_enum::alibi) { @@ -737,7 +864,41 @@ fwd_result fmha_fwd_run(mode_enum mode, } } } - if(qscale.type == quant_scale_enum::pertensor) + if constexpr(is_mx) + { + auto gen_scales = [&](auto& scales, auto data, float range) { + using DataType = decltype(data); + using ScaleType = ck_tile::remove_cvref_t; + if constexpr(std::is_same_v) + { + const float base = + -std::log2(ck_tile::type_convert(ck_tile::numeric::max())); + // e8m0_t is basically an exponent of float32 + // When scales are applied to tensor values, value * exp2(base - range) is around + // 0.125 and value * exp2(base + range) is around 8 for all types (fp8/bf8/fp4) + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + base - range, base + range, next_seed()}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = ck_tile::type_convert(std::exp2(pow2(i))); + }); + } + else + { + static_assert(false); + } + }; + gen_scales(q_descale_host, QDataType{}, 3); + gen_scales(k_descale_host, KDataType{}, 3); + // When P is fp4, only 8 values (0, 0.5, 1, 1.5, 2, 3, 4, 6) are used to quantize P. + // Too large V values can create rare error outliers between host (no quantization) and + // device ("running" FA softmax + quantization), here we reduce max value by using smaller + // range of V scales. + gen_scales(v_descale_host, + VDataType{}, + std::is_same_v ? 1 : 3); + } + else if(qscale.type == quant_scale_enum::pertensor) { float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); @@ -790,12 +951,13 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(int32_t)); ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem scale_seqstart_v_buf(seqstart_v_scale_host.size() * sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_buf(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_buf(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() ? 0 : seqstart_q_with_padding_host.size() * @@ -837,9 +999,10 @@ fwd_result fmha_fwd_run(mode_enum mode, v_descale_buf.ToDevice(v_descale_host.data()); block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); - seqstart_q.ToDevice(seqstart_q_host.data()); - // Keep logical starts in seqstart_k; pass padded K via separate pointer - seqstart_k.ToDevice(seqstart_k_host.data()); + scale_seqstart_v_buf.ToDevice(seqstart_v_scale_host.data()); + seqstart_q_buf.ToDevice(seqstart_q_host.data()); + // Keep logical starts in seqstart_k_buf; pass padded K via separate pointer + seqstart_k_buf.ToDevice(seqstart_k_host.data()); seqstart_q_padded_buf.ToDevice( seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr @@ -1145,7 +1308,13 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - if(qscale.type == quant_scale_enum::blockscale) + if(qscale.type == quant_scale_enum::pertensor) + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } + else if(qscale.type == quant_scale_enum::blockscale) { args.q_descale_ptr = reinterpret_cast(q_descale_buf.GetDeviceBuffer()); @@ -1172,11 +1341,32 @@ fwd_result fmha_fwd_run(mode_enum mode, args.block_scale_size_q = block_scale_size_q_; args.block_scale_size_kv = block_scale_size_kv_; } - else + else if(qscale.type == quant_scale_enum::mx) { args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + + args.stride_q_descale = (i_perm ? hdim_q_scale : nhead * hdim_q_scale); + args.stride_k_descale = (i_perm ? hdim_q_scale : nhead_k * hdim_q_scale); + args.stride_v_descale = + (i_perm ? shape_seqlen_v_scale : nhead_k * shape_seqlen_v_scale); + args.nhead_stride_q_descale = + (i_perm ? shape_seqlen_q * hdim_q_scale : hdim_q_scale); + args.nhead_stride_k_descale = + (i_perm ? shape_seqlen_k * hdim_q_scale : hdim_q_scale); + args.nhead_stride_v_descale = + (i_perm ? hdim_v * shape_seqlen_v_scale : shape_seqlen_v_scale); + if(mode == mode_enum::group) + { + args.seqstart_v_scale_ptr = scale_seqstart_v_buf.GetDeviceBuffer(); + } + else + { + args.batch_stride_q_descale = (nhead * shape_seqlen_q * hdim_q_scale); + args.batch_stride_k_descale = (nhead_k * shape_seqlen_k * hdim_q_scale); + args.batch_stride_v_descale = (nhead_k * hdim_v * shape_seqlen_v_scale); + } } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1207,11 +1397,11 @@ fwd_result fmha_fwd_run(mode_enum mode, args.seqstart_q_ptr = has_group_q_padding && !seqstart_q_with_padding_host.empty() ? seqstart_q_padded_buf.GetDeviceBuffer() - : seqstart_q.GetDeviceBuffer(); + : seqstart_q_buf.GetDeviceBuffer(); args.seqstart_k_ptr = has_group_k_padding && !seqstart_k_with_padding_host.empty() ? seqstart_k_padded_buf.GetDeviceBuffer() - : seqstart_k.GetDeviceBuffer(); + : seqstart_k_buf.GetDeviceBuffer(); // Logical (unpadded) per-sequence lengths, used when padding is enabled args.seqlen_q_ptr = @@ -1272,9 +1462,9 @@ fwd_result fmha_fwd_run(mode_enum mode, args.split_stride_o_acc = split_stride_o_acc; args.seqstart_q_ptr = - (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = - (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr); args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() @@ -1292,9 +1482,9 @@ fwd_result fmha_fwd_run(mode_enum mode, (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); args.seqstart_q_ptr = - (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = - (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + (mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr); args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() @@ -1462,11 +1652,14 @@ fwd_result fmha_fwd_run(mode_enum mode, float scale_p_host = 1.0f; float scale_o_host = 1.0f; - if(qscale.type == quant_scale_enum::pertensor) + if constexpr(!is_mx) { - scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0); - scale_p_host = ck_tile::type_convert(ck_tile::numeric::max()); - scale_o_host = v_descale_host(0) / scale_p_host; + if(qscale.type == quant_scale_enum::pertensor) + { + scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0); + scale_p_host = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o_host = v_descale_host(0) / scale_p_host; + } } auto p_compute_element_func = [&]() { @@ -1680,7 +1873,45 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - if(qscale.type == quant_scale_enum::blockscale) + if constexpr(is_mx) + { + ck_tile::HostTensor q_descale_host_ref( + {nhead, real_seqlen_q, hdim_q_scale}); + ck_tile::HostTensor k_descale_host_ref( + {nhead, real_seqlen_k, hdim_q_scale}); + + // clang-format off + if(i_perm) q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // clang-format off + if(i_perm) k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + // clang-format on + + auto q_host_ref2 = ck_tile::reference_batched_mx_descale( + q_host_ref, q_descale_host_ref, kQKScaleGranularity); + auto k_host_ref2 = ck_tile::reference_batched_mx_descale( + k_host_ref, k_descale_host_ref, kQKScaleGranularity); + + ck_tile::reference_batched_gemm(q_host_ref2, + k_host_ref2, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } + else if(qscale.type == quant_scale_enum::blockscale) { const ck_tile::index_t q_offset = (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; @@ -1881,6 +2112,32 @@ fwd_result fmha_fwd_run(mode_enum mode, s_host_ref, p_host_ref, p_compute_element_func); } } + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + }); + + // Use smaller rtol/atol as LSE is computed and stored in fp32, so there is no + // precision loss due to conversion + bool cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + 1e-4, + 1e-4, + /* allow_infinity_ref = */ true); + + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + } + } if(p_drop > 0) { ck_tile::HostTensor randval_host_ref( @@ -1907,13 +2164,43 @@ fwd_result fmha_fwd_run(mode_enum mode, randval_host_ref, "DROPOUT RANDVAL Error: Incorrect results!"); pass &= cur_pass; - if(!cur_pass) - { - break; - } } - if(qscale.type == quant_scale_enum::blockscale) + if constexpr(is_mx) + { + const ck_tile::index_t real_seqlen_v_scale = + seqstart_v_scale_host[wb + 1] - seqstart_v_scale_host[wb]; + const ck_tile::index_t v_scale_offset = + mode == mode_enum::batch ? 0 : seqstart_v_scale_host[wb]; + + ck_tile::HostTensor v_descale_host_ref( + {nhead, hdim_v, real_seqlen_v_scale}); + + // clang-format off + if(i_perm) v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[0] / nr, i[1], i[2] + v_scale_offset); }); + else v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[1], i[0] / nr, i[2] + v_scale_offset); }); + // clang-format on + + auto v_host_ref2 = ck_tile::reference_batched_mx_descale( + v_host_ref, v_descale_host_ref, kVScaleGranularity); + + // P is not quantized and then dequantized here (PDataType = float). + // On host softmax is computed for the whole row of S, while on device FA computes + // softmax and quantizes it in blocks of N0 values. Quantization on host would make + // reference results **less** precise than the device results for large seqlen_k! + + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref2, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } + else if(qscale.type == quant_scale_enum::blockscale) { const ck_tile::index_t v_offset = (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; @@ -1969,35 +2256,11 @@ fwd_result fmha_fwd_run(mode_enum mode, << std::endl << "\tquery_offset used: " << query_offset << std::endl << "\tkey_offset used: " << key_offset << std::endl; - - break; } - if(lse) + if(!pass) { - ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); - }); - - cur_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); - - pass &= cur_pass; - if(!cur_pass) - { - std::cerr << "LSE mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - - break; - } + break; } } @@ -2006,33 +2269,37 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results( - *json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale - ? "no_scale" - : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + const std::string qscale_name = + (qscale.type == quant_scale_enum::no_scale ? "no_scale" + : qscale.type == quant_scale_enum::pertensor ? "pertensor" + : qscale.type == quant_scale_enum::blockscale ? "blockscale" + : qscale.type == quant_scale_enum::kv_blockscale ? "kv_blockscale" + : qscale.type == quant_scale_enum::mx ? "mx" + : "unknown"); + dump_fmha_fwd_json_results(*json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale_name, + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 833a025f79..4b8cd2e9a4 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -18,6 +18,7 @@ enum class quant_scale_enum pertensor = 1, blockscale = 2, kv_blockscale = 3, // Q per-tensor, K/V per-page block scale + mx = 4, // Microscaling (MX) }; struct quant_scale_info @@ -34,6 +35,8 @@ struct quant_scale_info os << "bs"; else if(type == quant_scale_enum::kv_blockscale) os << "kvbs"; + else if(type == quant_scale_enum::mx) + os << "mx"; } static quant_scale_info decode(std::string str) @@ -55,6 +58,10 @@ struct quant_scale_info { info.type = quant_scale_enum::kv_blockscale; } + else if(str == "mx" || str == "4") + { + info.type = quant_scale_enum::mx; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 246b2b85a7..f7dc610717 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2693,7 +2693,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, else { thread_buffer tmp; - tmp.template set_as(number<0>{}, vector_t{customized_value}); + tmp.template set_as( + number<0>{}, vector_t{static_cast(customized_value)}); return tmp; } } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 366dc7a6d8..545ef73e46 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2519,7 +2519,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, else { thread_buffer tmp; - tmp.template set_as(number<0>{}, vector_t{customized_value}); + tmp.template set_as( + number<0>{}, vector_t{static_cast(customized_value)}); return tmp; } } diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index f04879f7cd..995d854536 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -24,6 +24,7 @@ #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" +#include "ck_tile/host/reference/reference_batched_mx_descale.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_transpose.hpp" diff --git a/include/ck_tile/host/reference/reference_batched_mx_descale.hpp b/include/ck_tile/host/reference/reference_batched_mx_descale.hpp new file mode 100644 index 0000000000..5a47f0eee0 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_mx_descale.hpp @@ -0,0 +1,61 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST HostTensor +reference_batched_mx_descale(const HostTensor& a_b_m_k, + const HostTensor& scales_b_m_ks, + const std::size_t scale_granularity) +{ + const std::size_t B = a_b_m_k.get_length(0); + const std::size_t M = a_b_m_k.get_length(1); + const std::size_t K = a_b_m_k.get_length(2); + + HostTensor a_b_m_k_scaled(a_b_m_k.get_lengths()); + + auto f = [&](auto batch) { + constexpr index_t packed_size = ck_tile::numeric_traits::PackedSize; + + for(std::size_t m = 0; m < M; ++m) + { + for(std::size_t k = 0; k < K; k += packed_size) + { + const auto scale = ck_tile::type_convert( + scales_b_m_ks(batch, m, k / scale_granularity)); + + if constexpr(std::is_same_v) + { + auto a_f4x2 = a_b_m_k(batch, m, k); + auto a_f4_lo = ck_tile::type_convert( + a_f4x2.template unpack<>(number<0>{})); + auto a_f4_hi = ck_tile::type_convert( + a_f4x2.template unpack<>(number<1>{})); + + a_b_m_k_scaled(batch, m, k) = a_f4_lo * scale; + a_b_m_k_scaled(batch, m, k + 1) = a_f4_hi * scale; + } + else + { + a_b_m_k_scaled(batch, m, k) = + ck_tile::type_convert(a_b_m_k(batch, m, k)) * scale; + } + } + } + }; + make_ParallelTensorFunctor(f, B)(std::thread::hardware_concurrency()); + + return a_b_m_k_scaled; +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 0639fa1b36..8a5d77bf46 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/fmha/block/cast_tile_mx.hpp" #include "ck_tile/ops/fmha/block/page_block_navigator.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 0c6075e063..61051cc08a 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -14,6 +14,7 @@ enum class BlockAttentionQuantScaleEnum PERTENSOR = 1, BLOCKSCALE = 2, KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale + MX = 4, // Microscaling }; template @@ -34,5 +35,15 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "kv_blockscale"; +}; +template <> +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "mx"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/cast_tile_mx.hpp b/include/ck_tile/ops/fmha/block/cast_tile_mx.hpp new file mode 100644 index 0000000000..6efc1055bd --- /dev/null +++ b/include/ck_tile/ops/fmha/block/cast_tile_mx.hpp @@ -0,0 +1,186 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +cast_tile_mx(DstTensor& dst_tensor, DstScaleTensor& dst_scale_tensor, const SrcTensor& src_tensor) +{ + using DstDataType = remove_cv_t; + using DstScaleDataType = remove_cv_t; + + static_assert(SrcTensor::get_thread_buffer_size() == + DstScaleTensor::get_thread_buffer_size() * ScaleGranularity); + + constexpr index_t size = SrcTensor::get_thread_buffer_size(); + + const auto src_thread_buffer = cast_tile(src_tensor).get_thread_buffer(); + + if constexpr(std::is_same_v) + { + static_for<0, size / 32, 1>{}([&](auto i) { + // Maximum of consecutive ScaleGranularity values + // (1 lane, 32 per lane for fp4) + float max_abs = 0; + static_for<0, 32, 1>{}([&](auto j) { + max_abs = max(max_abs, abs(src_thread_buffer[number{}])); + }); + + static_assert(std::is_same_v); + // Use literal because type_convert(numeric::max()) is not constexpr + // causing the result of div to be stored in a VGPR + constexpr float rcp_dst_max = 1.0f / 6.0f; + // For e8m0 scales round up to the next power of 2, equivalent of exp2(ceil(log2(x))) + float scale = bit_cast( + (bit_cast(max_abs * rcp_dst_max) + numeric_traits::mant_mask) & + numeric_traits::head_mask); + + // Convert using scales + + static_for<0, 32 / 8, 1>{}([&](auto j) { + using vec_t = uint32_t; + // These builtins require the old value, and will generate a v_mov_b32 + // vxxx [old] before cvt, which result in unwanted ISA so we prepare an + // uninitialized variable x purposely, and turn off the warning +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + vec_t x; + x = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + 0); // byte 0 + x = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + 1); // byte 1 + x = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + 2); // byte 2 + x = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + 3); // byte 3 + dst_tensor.get_thread_buffer().template set_as(number{}, x); +#pragma clang diagnostic pop + }); + + // Save scale for the corresponding lane + // No additional processing is needed because each lane computes scale based only on its + // own values. + dst_scale_tensor.get_thread_buffer()(i) = type_convert(scale); + }); + } + else + { + const index_t lane = __lane_id(); + float scale_result = 0; + static_for<0, size / 16, 1>{}([&](auto i) { + // Maximum of consecutive ScaleGranularity values + // (2 lanes, 16 per lane for fp8/bf8) + float max_abs = 0; + static_for<0, 16, 1>{}([&](auto j) { + max_abs = max(max_abs, abs(src_thread_buffer[number{}])); + }); + // 2 lanes, 16 values per lane share one scale + max_abs = max(max_abs, warp_shuffle(max_abs, lane ^ MLane)); + + static_assert(std::is_same_v); + // Use literal because type_convert(numeric::max()) is not constexpr + // causing the result of div to be stored in a VGPR + constexpr float rcp_dst_max = + 1.0f / (std::is_same_v ? 448.0f : 57344.0f); + // For e8m0 scales round up to the next power of 2, equivalent of exp2(ceil(log2(x))) + float scale = bit_cast( + (bit_cast(max_abs * rcp_dst_max) + numeric_traits::mant_mask) & + numeric_traits::head_mask); + + // Convert using scales + + static_for<0, 16 / 4, 1>{}([&](auto j) { + using vec_t = ext_vector_t; + // These builtins require the old value, and will generate a v_mov_b32 + // vxxx [old] before cvt, which result in unwanted ISA so we prepare an + // uninitialized variable x purposely, and turn off the warning +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + vec_t x; + if constexpr(std::is_same_v) + { + x = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + false); // false -> WORD0 + x = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + true); // true -> WORD1 + } + else + { + x = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + false); // false -> WORD0 + x = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + x, + src_thread_buffer[number{}], + src_thread_buffer[number{}], + scale, + true); // true -> WORD1 + } + dst_tensor.get_thread_buffer().template set_as(number{}, x); +#pragma clang diagnostic pop + }); + + // Save scale for the corresponding lane + // Two iterations are needed to compute scales for all kABKLane lanes. + // 32x32x64, 2 lanes per row (kABKLane = 2): + // scale_result for lanes 00..31 <- scale for lanes 00..31, iteration 0 + // scale_result for lanes 32..63 <- scale for lanes 32..63, iteration 1 + // 16x16x128, 4 lanes per row (kABKLane = 4), one extra exchange is needed: + // scale_result for lanes 00..15 <- scale for lanes 00..31, iteration 0 + // scale_result for lanes 16..31 <- scale for lanes 32..63, iteration 0 + // scale_result for lanes 32..47 <- scale for lanes 00..31, iteration 1 + // scale_result for lanes 48..64 <- scale for lanes 32..63, iteration 1 + if constexpr(MLane == 16) // 16x16x128 + { + scale = warp_shuffle(scale, (lane % MLane) | ((lane & MLane) << 1)); + } + if((i % 2 == 0) == (lane < 32)) + { + scale_result = scale; + } + if constexpr(i % 2 == 1) + { + dst_scale_tensor.get_thread_buffer()(number{}) = + type_convert(scale_result); + } + }); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0039c57cfc..bd09453dbb 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -191,6 +191,29 @@ struct FmhaFwdKernel const int32_t* block_scale_seqstart_k_ptr; }; + struct FmhaFwdCommonMXKargs : FmhaFwdCommonQScaleKargs + { + ck_tile::index_t stride_q_descale; + ck_tile::index_t stride_k_descale; + ck_tile::index_t stride_v_descale; + + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + }; + + struct FmhaFwdBatchMXKargs : FmhaFwdCommonMXKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + + struct FmhaFwdGroupMXKargs : FmhaFwdCommonMXKargs + { + const int32_t* seqstart_v_scale_ptr; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -271,7 +294,9 @@ struct FmhaFwdKernel FmhaFwdCommonQScaleKargs, std::conditional_t>>, + std::conditional_t>>>, std::conditional_t>, std::conditional_t> { @@ -300,7 +325,9 @@ struct FmhaFwdKernel FmhaFwdCommonQScaleKargs, std::conditional_t>>, + std::conditional_t>>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -350,6 +377,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -450,7 +480,7 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { kargs.q_descale_ptr = q_descale_ptr; kargs.k_descale_ptr = k_descale_ptr; @@ -467,6 +497,24 @@ struct FmhaFwdKernel kargs.block_scale_size_q = block_scale_size_q; kargs.block_scale_size_kv = block_scale_size_kv; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.stride_q_descale = stride_q_descale; + kargs.stride_k_descale = stride_k_descale; + kargs.stride_v_descale = stride_v_descale; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -525,6 +573,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -583,6 +634,9 @@ struct FmhaFwdKernel stride_bias, stride_randval, stride_o, + stride_q_descale, + stride_k_descale, + stride_v_descale, nhead_stride_q, nhead_stride_k, nhead_stride_v, @@ -644,6 +698,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -702,6 +759,9 @@ struct FmhaFwdKernel stride_bias, stride_randval, stride_o, + stride_q_descale, + stride_k_descale, + stride_v_descale, nhead_stride_q, nhead_stride_k, nhead_stride_v, @@ -754,6 +814,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, const void* block_scale_seqstart_q_ptr, const void* block_scale_seqstart_k_ptr, + const void* seqstart_v_scale_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -766,6 +827,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -856,7 +920,7 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) { kargs.q_descale_ptr = q_descale_ptr; kargs.k_descale_ptr = k_descale_ptr; @@ -874,6 +938,22 @@ struct FmhaFwdKernel kargs.block_scale_seqstart_k_ptr = reinterpret_cast(block_scale_seqstart_k_ptr); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.stride_q_descale = stride_q_descale; + kargs.stride_k_descale = stride_k_descale; + kargs.stride_v_descale = stride_v_descale; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.seqstart_v_scale_ptr = reinterpret_cast(seqstart_v_scale_ptr); + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -939,6 +1019,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -992,6 +1075,9 @@ struct FmhaFwdKernel stride_bias, stride_randval, stride_o, + stride_q_descale, + stride_k_descale, + stride_v_descale, nhead_stride_q, nhead_stride_k, nhead_stride_v, @@ -1048,6 +1134,9 @@ struct FmhaFwdKernel ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, + ck_tile::index_t stride_q_descale, + ck_tile::index_t stride_k_descale, + ck_tile::index_t stride_v_descale, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, @@ -1101,6 +1190,9 @@ struct FmhaFwdKernel stride_bias, stride_randval, stride_o, + stride_q_descale, + stride_k_descale, + stride_v_descale, nhead_stride_q, nhead_stride_k, nhead_stride_v, @@ -1303,6 +1395,12 @@ struct FmhaFwdKernel batch_offset_k_descale = bkey_start; batch_offset_v_descale = bkey_start; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + batch_offset_q_descale = query_start * kargs.stride_q_descale; + batch_offset_k_descale = key_start * kargs.stride_k_descale; + batch_offset_v_descale = kargs.seqstart_v_scale_ptr[i_batch]; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1370,7 +1468,8 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::MX) { batch_offset_q_descale = static_cast(i_batch) * kargs.batch_stride_q_descale; @@ -1395,17 +1494,20 @@ struct FmhaFwdKernel } // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; + const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; + + const QDataType* q_ptr = + reinterpret_cast(kargs.q_ptr) + + (static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q) / + numeric_traits::PackedSize; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; + (static_cast(i_nhead_k) * kargs.nhead_stride_k + batch_offset_k) / + numeric_traits::PackedSize; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; + (static_cast(i_nhead_k) * kargs.nhead_stride_v + batch_offset_v) / + numeric_traits::PackedSize; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; @@ -1698,9 +1800,9 @@ struct FmhaFwdKernel } }(); - BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + BlockIndices block_indices{i_batch, i_nhead, i_nhead_k}; - auto o_acc_tile = [&, i_nhead_ = i_nhead]() { + auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1744,6 +1846,9 @@ struct FmhaFwdKernel nullptr, nullptr, 1, + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), sink_value); } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) @@ -1795,8 +1900,144 @@ struct FmhaFwdKernel k_descale_ptr, v_descale_ptr, kargs.block_scale_size_kv, + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), sink_value); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + using QScaleDataType = typename FmhaPipeline::QScaleDataType; + using KScaleDataType = typename FmhaPipeline::KScaleDataType; + using VScaleDataType = typename FmhaPipeline::VScaleDataType; + + constexpr ck_tile::index_t kQKScaleGranularity = + FmhaPipeline::kQKScaleGranularity; + constexpr ck_tile::index_t kVScaleGranularity = + FmhaPipeline::kVScaleGranularity; + + const QScaleDataType* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const KScaleDataType* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead_k_) * kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const VScaleDataType* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead_k_) * kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + const ck_tile::index_t hdim_q_scale = + ck_tile::integer_divide_ceil(kargs.hdim_q, kQKScaleGranularity); + const ck_tile::index_t seqlen_v_scale = + ck_tile::integer_divide_ceil(kargs.seqlen_k, kVScaleGranularity); + + // Custom invalid_element_value is required for e8m0_t scales because + // the default (numeric>::zero()) is NaN + const auto q_scale_dram = [&]() { + auto desc = + make_naive_tensor_descriptor(make_tuple(kargs.seqlen_q, hdim_q_scale), + make_tuple(kargs.stride_q_descale, 1), + number<1>{}, + number<1>{}); + auto buffer_view = make_buffer_view( + q_descale_ptr, + desc.get_element_space_size(), + type_convert(1.0f)); + return pad_tensor_view( + tensor_view{buffer_view, desc}, + make_tuple( + number{}, + number<(FmhaPipeline::kQLoadOnce ? FmhaPipeline::kSubQKHeaddim + : FmhaPipeline::kK0) / + kQKScaleGranularity>{}), + sequence{}); + }(); + const auto k_scale_dram = [&]() { + auto desc = + make_naive_tensor_descriptor(make_tuple(kargs.seqlen_k, hdim_q_scale), + make_tuple(kargs.stride_k_descale, 1), + number<1>{}, + number<1>{}); + auto buffer_view = make_buffer_view( + k_descale_ptr, + desc.get_element_space_size(), + type_convert(1.0f)); + return pad_tensor_view( + tensor_view{buffer_view, desc}, + make_tuple(number{}, + number{}), + sequence{}); + }(); + const auto v_scale_dram = [&]() { + static_assert( + std::is_same_v); + auto desc = + make_naive_tensor_descriptor(make_tuple(kargs.hdim_v, seqlen_v_scale), + make_tuple(kargs.stride_v_descale, 1), + number<1>{}, + number<1>{}); + auto buffer_view = make_buffer_view( + v_descale_ptr, + desc.get_element_space_size(), + type_convert(1.0f)); + return pad_tensor_view( + tensor_view{buffer_view, desc}, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + auto q_scale_dram_window = make_tile_window( + q_scale_dram, + make_tuple(number{}, + number<(FmhaPipeline::kQLoadOnce ? FmhaPipeline::kSubQKHeaddim + : FmhaPipeline::kK0) / + kQKScaleGranularity>{}), + {i_m0, 0}); + auto k_scale_dram_window = make_tile_window( + k_scale_dram, + make_tuple(number{}, + number{}), + {0, 0}); + auto v_scale_dram_window = make_tile_window( + v_scale_dram, + make_tuple(number{}, + number{}), + {i_n1, 0}); + + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + nullptr, + nullptr, + 1, + q_scale_dram_window, + k_scale_dram_window, + v_scale_dram_window, + sink_value); + } else { return FmhaPipeline{}(q_dram_window, @@ -1969,15 +2210,18 @@ struct FmhaFwdKernel // for simplicity, batch stride we just modify the pointer const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead_k) * kargs.nhead_stride_v + - batch_offset_v; + const QDataType* q_ptr = + reinterpret_cast(kargs.q_ptr) + + (static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q) / + numeric_traits::PackedSize; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + (static_cast(i_nhead_k) * kargs.nhead_stride_k + batch_offset_k) / + numeric_traits::PackedSize; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + (static_cast(i_nhead_k) * kargs.nhead_stride_v + batch_offset_v) / + numeric_traits::PackedSize; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + @@ -2006,7 +2250,8 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), sequence{}); #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); + constexpr index_t LDSLayerSize = + 256 * numeric_traits::PackedSize / sizeof(QDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); if constexpr(XorLengthFold > 1) @@ -2130,7 +2375,8 @@ struct FmhaFwdKernel FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0; #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); + constexpr index_t LDSLayerSize = + 256 * numeric_traits::PackedSize / sizeof(KDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); if constexpr(XorLengthFold > 1) @@ -2254,7 +2500,8 @@ struct FmhaFwdKernel sequence{}); #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD - constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); + constexpr index_t LDSLayerSize = + 256 * numeric_traits::PackedSize / sizeof(VDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); if constexpr(XorLengthFold > 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index eabf74faf8..87db7b85b9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -44,6 +44,15 @@ struct BlockFmhaPipelineProblem using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; + // TODO: Pass scale types and granularity from FmhaFwdTypeConfig + using QScaleDataType = ck_tile::e8m0_t; + using KScaleDataType = ck_tile::e8m0_t; + using VScaleDataType = ck_tile::e8m0_t; + using PScaleDataType = ck_tile::e8m0_t; + + static constexpr ck_tile::index_t kQKScaleGranularity = 32; + static constexpr ck_tile::index_t kVScaleGranularity = 32; + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 35654840bd..b207c62181 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/cast_tile_mx.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -29,6 +30,10 @@ struct BlockFmhaPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using QScaleDataType = remove_cvref_t; + using KScaleDataType = remove_cvref_t; + using VScaleDataType = remove_cvref_t; + using PScaleDataType = remove_cvref_t; using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; @@ -61,6 +66,9 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity; + static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] static constexpr float OCP_FP8_SHIFT = 8.0f; static constexpr float FNUZ_FP8_SHIFT = 7.0f; @@ -75,15 +83,16 @@ struct BlockFmhaPipelineQRKSVS // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits::PackedSize + : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits::PackedSize + : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = []() { if constexpr(std::is_same_v) return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + return kPadSeqLenK ? numeric_traits::PackedSize + : Policy::template GetAlignmentV(); }(); static constexpr index_t kAlignmentO = @@ -149,7 +158,10 @@ struct BlockFmhaPipelineQRKSVS typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, - typename BlockIndices> + typename BlockIndices, + typename QScaleDramBlockWindowTmp, + typename KScaleDramBlockWindowTmp, + typename VScaleDramBlockWindowTmp> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -176,6 +188,12 @@ struct BlockFmhaPipelineQRKSVS const float* k_descale_ptr, const float* v_descale_ptr, const index_t block_scale_size_kv, + const QScaleDramBlockWindowTmp& + q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile + const KScaleDramBlockWindowTmp& + k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile + const VScaleDramBlockWindowTmp& + v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile const float sink_v) const { static_assert( @@ -185,6 +203,8 @@ struct BlockFmhaPipelineQRKSVS "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kSubQKHeaddim == + QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -193,6 +213,29 @@ struct BlockFmhaPipelineQRKSVS kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + static_assert(std::is_same_v); + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>); + static_assert(kM0 == QScaleDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kSubQKHeaddim == + QScaleDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] * + kQKScaleGranularity && + kN0 == KScaleDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KScaleDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] * + kQKScaleGranularity && + kN1 == VScaleDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VScaleDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] * + kVScaleGranularity); + } + // K tile in LDS KDataType* k_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); @@ -331,13 +374,54 @@ struct BlockFmhaPipelineQRKSVS auto q_tile = tile_elementwise_in(q_element_func, q); + auto q_scale = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + auto q_scale_dram_window = + make_tile_window(q_scale_dram_block_window_tmp.get_bottom_tensor_view(), + q_scale_dram_block_window_tmp.get_window_lengths(), + q_scale_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQScaleRegTileDistribution()); + return load_tile(q_scale_dram_window); + } + else + { + return null_tensor{}; + } + }(); + auto k_scale_dram_block_window = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + return make_tile_window(k_scale_dram_block_window_tmp.get_bottom_tensor_view(), + k_scale_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + } + else + { + return make_null_tile_window(make_tuple()); + } + }(); + auto v_scale_dram_window = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + return make_tile_window(v_scale_dram_block_window_tmp.get_bottom_tensor_view(), + v_scale_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start / kVScaleGranularity}, + Policy::template MakeVScaleRegTileDistribution()); + } + else + { + return make_null_tile_window(make_tuple()); + } + }(); + // prefetch K tile index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; // Use compile-time conditional for group barrier sequence // (No runtime lambda selection) - auto schedule_gemm0 = [] { + auto schedule_gemm_0 = [] { using BlockGemm0 = remove_cvref_t; constexpr auto WarpGemmConfig = BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); @@ -381,6 +465,32 @@ struct BlockFmhaPipelineQRKSVS k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + auto k_scale_dram_window = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + return make_tile_window( + k_scale_dram_block_window.get_bottom_tensor_view(), + k_scale_dram_block_window.get_window_lengths(), + k_scale_dram_block_window.get_window_origin(), + Policy::template MakeKScaleRegTileDistribution()); + } + else + { + return make_null_tile_window(make_tuple()); + } + }(); + auto load_k_scale_block_tile = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + auto t = load_tile(k_scale_dram_window); + move_tile_window(k_scale_dram_window, {0, kK0 / kQKScaleGranularity}); + return t; + } + else + { + return make_null_tile_window(make_tuple()); + } + }; auto k_block_tile = load_tile(k_dram_window); { @@ -389,6 +499,7 @@ struct BlockFmhaPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + auto k_scale_block_tile = load_k_scale_block_tile(); if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -402,16 +513,29 @@ struct BlockFmhaPipelineQRKSVS 0); // prevent from messing up the order of global loads } + auto run_gemm_0 = [&](auto i_k0) { + auto q_slice = get_slice_tile( + q_tile, sequence<0, i_k0 * kK0>{}, sequence{}); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + auto q_scale_slice = + get_slice_tile(q_scale, + sequence<0, i_k0*(kK0 / kQKScaleGranularity)>{}, + sequence{}); + gemm_0(s_acc, q_slice, q_scale_slice, k_lds_window, k_scale_block_tile); + } + else + { + gemm_0(s_acc, q_slice, k_lds_window); + schedule_gemm_0(); + } + }; + if constexpr(k0_loops > 2) { static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_window); - schedule_gemm0(); + run_gemm_0(number{}); block_sync_lds(); move_tile_window(k_dram_window, {0, kK0}); @@ -419,29 +543,24 @@ struct BlockFmhaPipelineQRKSVS k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 k_block_tile = load_tile(k_dram_window); // global read i + 2 + + k_scale_block_tile = load_k_scale_block_tile(); }); } const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile { // tail block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 2) * kK0>{}, - sequence{}), - k_lds_window); - schedule_gemm0(); + run_gemm_0(number{}); block_sync_lds(); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + + k_scale_block_tile = load_k_scale_block_tile(); + block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_window); - schedule_gemm0(); + run_gemm_0(number{}); } // dequant auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { @@ -718,15 +837,19 @@ struct BlockFmhaPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); -#if defined(__gfx11__) - auto p = make_static_distributed_tensor( - decltype(gemm_1)::template MakeABlockTileDistribution()); - PermuteWarpGemmCToA( - p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute))); -#else - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); -#endif + auto load_v_scale_block_tile = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + auto t = load_tile(v_scale_dram_window); + move_tile_window(v_scale_dram_window, {0, kK1 / kVScaleGranularity}); + return t; + } + else + { + return make_null_tile_window(make_tuple()); + } + }; + auto v_scale_block_tile = load_v_scale_block_tile(); float v_descale = 1.0f; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) @@ -735,29 +858,73 @@ struct BlockFmhaPipelineQRKSVS const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; v_descale = v_descale_ptr[kv_idx]; } + + const auto p_p_scale = [&] { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + auto p_result = make_static_distributed_tensor( + p_compute.get_tile_distribution()); + auto p_scale_result = make_static_distributed_tensor( + Policy::template MakePScaleRegTileDistribution()); + + constexpr auto config = + decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + cast_tile_mx( + p_result, p_scale_result, p_compute); + + return make_tuple(p_result, p_scale_result); + } + else + { +#if defined(__gfx11__) + auto p_result = make_static_distributed_tensor( + decltype(gemm_1)::template MakeABlockTileDistribution()); + PermuteWarpGemmCToA(p_result, + cast_tile(tile_elementwise_in( + p_compute_element_func, p_compute))); +#else + const auto p_result = cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); +#endif + return make_tuple(p_result, null_tensor{}); + } + }(); + const auto p = p_p_scale[number<0>{}]; + const auto p_scale = p_p_scale[number<1>{}]; + // STAGE 3, KV gemm auto o_acc0 = decltype(o_acc){}; clear_tile(o_acc0); - auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + auto run_gemm_1 = [&](auto i_k1) { + auto p_slice = + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) { - return o_acc0; + auto p_scale_slice = + get_slice_tile(p_scale, + sequence<0, i_k1*(kK1 / kVScaleGranularity)>{}, + sequence{}); + gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile); + } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + gemm_1(o_acc0, p_slice, v_lds_window); } else { - return o_acc; + gemm_1(o_acc, p_slice, v_lds_window); } - }(); + }; + if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc_, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_window); + run_gemm_1(number{}); block_sync_lds(); if constexpr(std::is_same_v) { @@ -774,6 +941,7 @@ struct BlockFmhaPipelineQRKSVS tile_elementwise_in(v_element_func, v)); // store next v } move_tile_window(v_dram_window, {0, kK1}); + v_scale_block_tile = load_v_scale_block_tile(); }); } // move K tile windows @@ -786,12 +954,14 @@ struct BlockFmhaPipelineQRKSVS } } move_tile_window(k_dram_block_window, {kN0, 0}); + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + move_tile_window(k_scale_dram_block_window, {kN0, 0}); + } // tail { block_sync_lds(); - gemm_1(o_acc_, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_window); + run_gemm_1(number{}); block_sync_lds(); } if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) @@ -921,6 +1091,9 @@ struct BlockFmhaPipelineQRKSVS nullptr, nullptr, 1, + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index cfd842dc9d..7b97d01fa4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -171,7 +171,10 @@ struct BlockFmhaPipelineQRKSVSAsync typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, - typename BlockIndices> + typename BlockIndices, + typename QScaleDramBlockWindowTmp, + typename KScaleDramBlockWindowTmp, + typename VScaleDramBlockWindowTmp> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -198,6 +201,9 @@ struct BlockFmhaPipelineQRKSVSAsync const float* k_descale_ptr, const float* v_descale_ptr, const index_t block_scale_size_kv, + const QScaleDramBlockWindowTmp&, // M0*(K0/kQKScaleGranularity) tile + const KScaleDramBlockWindowTmp&, // N0*(K0/kQKScaleGranularity) tile + const VScaleDramBlockWindowTmp&, // N1*(K1/kVScaleGranularity) tile const float sink_v) const { static_assert( @@ -215,6 +221,8 @@ struct BlockFmhaPipelineQRKSVSAsync kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + static_assert(QScaleEnum != BlockAttentionQuantScaleEnum::MX); + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); // K tile in LDS @@ -986,6 +994,9 @@ struct BlockFmhaPipelineQRKSVSAsync nullptr, nullptr, 1, + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4acd5d7250..581bcc19d4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -16,9 +16,18 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp" namespace ck_tile { +namespace detail { + +template +using has_qscale_enum_type = decltype(T::QScaleEnum); + +} // namespace detail + template struct BlockFmhaPipelineQXCustomPolicy; @@ -38,7 +47,10 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + using QDataType = remove_cvref_t; + + constexpr index_t MaxVectorSize = + 16 * numeric_traits::PackedSize / sizeof(QDataType); using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); @@ -57,6 +69,24 @@ struct BlockFmhaPipelineQXCustomPolicy Problem::BlockFmhaShape::kSubQKHeaddim>(); } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQScaleRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeAScaleBlockTileDistribution< + Problem::BlockFmhaShape::kM0, + Problem::BlockFmhaShape::kSubQKHeaddim>(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKScaleRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::MakeBScaleBlockTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -71,47 +101,109 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(get_warp_size() == 64 && - std::is_same_v && - std::is_same_v && - std::is_same_v) - { - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); - - // TODO: hard coded here. Otherwise, it produces incorrect results - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } + constexpr auto QScaleEnum = []() { + if constexpr(is_detected{}) + return Problem::QScaleEnum; else - { - constexpr bool SwizzleA = - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; + return ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE; + }(); + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + constexpr auto warp_gemm = []() { + static_assert(std::is_same_v == + std::is_same_v); + constexpr auto AttrNumAccess = std::is_same_v + ? WGAttrNumAccessEnum::Single + : WGAttrNumAccessEnum::Double; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true, // TransposeC - SwizzleA>{}; - } - }(); + true, // TransposeC + false, // SwizzleA + false, + AttrNumAccess>{}; + }(); - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; + // Ensure that QKBlockGemm's C (S) can be used as KVBlockGemm's A (P) + constexpr index_t TargetCMPerLane = [] { + // Must be consistent with GetKVBlockGemm() + constexpr auto AttrNumAccess = std::is_same_v + ? WGAttrNumAccessEnum::Single + : WGAttrNumAccessEnum::Double; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, // TransposeC + false, // SwizzleA + false, + AttrNumAccess>; + // fp8: kABKPerLane / WGAttrNumAccessEnum::Double = 16 + // fp4: kABKPerLane / WGAttrNumAccessEnum::Single = 32 + return WarpGemm::WarpGemmAttribute::Impl::kABKPerLane / + WarpGemm::WarpGemmAttribute::AttrNumAccessV; + }(); - if constexpr(1 < Problem::kNumGemm0Warps) - return BlockGemmARegBSmemCRegV2{}; + using BlockGemmPolicy = BlockGemmMxARegBSmemCRegV1CustomPolicy< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + decltype(warp_gemm)>; + + return BlockGemmMxARegBSmemCRegV1{}; + } else - return BlockGemmARegBSmemCRegOneWarpV1{}; + { + constexpr auto warp_gemm = []() { + if constexpr(get_warp_size() == 64 && + std::is_same_v && + std::is_same_v && + std::is_same_v && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32 && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32) + { + // TODO: hard coded here. Otherwise, it produces incorrect results + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else + { + constexpr bool SwizzleA = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; + return WarpGemmDispatcher< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true, // TransposeC + SwizzleA>{}; + } + }(); + + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + decltype(warp_gemm)>; + + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } } }; @@ -123,24 +215,27 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() { + using QDataType = remove_cvref_t; + constexpr index_t lds_alignment = 16; // optional - constexpr index_t q_smem_size = - ck_tile::integer_divide_ceil( - sizeof(typename Problem::QDataType) * - MakeQLdsBlockDescriptor().get_element_space_size(), - lds_alignment) * - lds_alignment; + constexpr index_t q_smem_size = ck_tile::integer_least_multiple( + sizeof(QDataType) * MakeQLdsBlockDescriptor().get_element_space_size() / + numeric_traits::PackedSize, + lds_alignment); return q_smem_size; } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { + using QDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + constexpr index_t MaxVectorSize = + 16 * numeric_traits::PackedSize / sizeof(QDataType); // this should align with MakeQDramTileDistribution() constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; @@ -157,7 +252,8 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); + constexpr index_t MaxVectorSize = + 16 * numeric_traits::PackedSize / sizeof(QDataType); constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; static_assert(0 < ElemPerThread); @@ -187,7 +283,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t kKPack = 16 / sizeof(QDataType); + constexpr index_t kKPack = 16 * numeric_traits::PackedSize / sizeof(QDataType); constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -223,12 +319,11 @@ struct BlockFmhaPipelineQXCustomPolicy if constexpr(get_warp_size() == 64 && std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32 && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32) { - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); - // TODO: hard coded here. Otherwise, it produces incorrect results constexpr index_t swizzle_factor = 4; return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< @@ -339,7 +434,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return 16 / sizeof(KDataType); + return 16 * numeric_traits::PackedSize / sizeof(KDataType); } template @@ -354,7 +449,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy::PackedSize / sizeof(KDataType); } else { @@ -362,7 +457,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy::PackedSize / sizeof(KDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; return min(MaxVectorSize, ElemPerThread); @@ -378,8 +474,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(16 / sizeof(VDataType))); + constexpr index_t kMaxVecLoad = min( + total_pixels, + static_cast(16 * numeric_traits::PackedSize / sizeof(VDataType))); return kMaxVecLoad; } @@ -393,12 +490,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(16 / sizeof(VDataType))); + constexpr index_t kMaxVecLoad = min( + total_pixels, + static_cast(16 * numeric_traits::PackedSize / sizeof(VDataType))); if constexpr(std::is_same_v) { - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + constexpr index_t kMinVecLoad = + 4 * numeric_traits::PackedSize / sizeof(VDataType); constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad @@ -477,10 +576,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t Banks = get_n_lds_banks(); - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackK(); + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = + Banks * 4 * numeric_traits::PackedSize / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; @@ -632,10 +732,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { - using VDataType = remove_cvref_t; - constexpr index_t Banks = get_n_lds_banks(); - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = + Banks * 4 * numeric_traits::PackedSize / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; @@ -672,10 +773,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { + using KDataType = remove_cvref_t; + // TODO: assume Q is in register // TODO: assume K/V has same data type - constexpr index_t single_smem_size = - GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + constexpr index_t single_smem_size = GetSingleSmemElementSpaceSize() * + sizeof(KDataType) / + numeric_traits::PackedSize; return QXPolicy::template GetSmemSizeQ() + single_smem_size * NumKVLdsBuffers; } @@ -735,7 +839,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy::PackedSize / sizeof(KDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); @@ -966,6 +1071,23 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr auto MakePScaleRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeAScaleBlockTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVScaleRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::MakeBScaleBlockTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { @@ -980,39 +1102,77 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy>; - auto warp_gemm = [&]() { - if constexpr(get_warp_size() == 64 && - std::is_same_v && - std::is_same_v && - std::is_same_v) - { - static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32); - - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{}; - } + constexpr auto QScaleEnum = []() { + if constexpr(is_detected{}) + return Problem::QScaleEnum; else - { + return ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE; + }(); + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) + { + constexpr auto warp_gemm = []() { + static_assert(std::is_same_v == + std::is_same_v); + constexpr auto AttrNumAccess = std::is_same_v + ? WGAttrNumAccessEnum::Single + : WGAttrNumAccessEnum::Double; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true>{}; - } - }(); + true, // TransposeC + false, // SwizzleA + false, + AttrNumAccess>{}; + }(); - using WarpGemm = remove_cvref_t; + using BlockGemmPolicy = BlockGemmMxARegBSmemCRegV1CustomPolicy< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + decltype(warp_gemm)>; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmMxARegBSmemCRegV1{}; + } + else + { + constexpr auto warp_gemm = []() { + if constexpr(get_warp_size() == 64 && + std::is_same_v && + std::is_same_v && + std::is_same_v && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + { + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{}; + } + else + { + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>{}; // TransposeC + } + }(); + + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + decltype(warp_gemm)>; + + return BlockGemmARegBSmemCRegV2{}; + } } }; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index f447ab4452..a0ed2fe9dd 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,6 +23,8 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..5dde03912a --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp @@ -0,0 +1,374 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// A scale is block distributed tensor +// B is block window on shared memory +// B scale is block distributed tensor +// C is block distributed tensor +// It supports only warp gemms with transposed C. +// TargetCMPerLane_ controls how many consecutive elements of matrix C are calculated by each lane. +template +struct BlockGemmMxARegBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t CMPerLane = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane * + WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane; + static constexpr index_t TargetCMPerLane = max(CMPerLane, TargetCMPerLane_); + + static_assert(TargetCMPerLane % CMPerLane == 0); + static constexpr index_t NIterPack = TargetCMPerLane / CMPerLane; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const AScaleBlockTensorTmp& a_scale_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp, + const BScaleBlockTensorTmp& b_scale_block_tensor_tmp) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>); + + static_assert(MPerBlock == ABlockTensorTmp{}.get_lengths()[number<0>{}] && + NPerBlock == BBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ABlockTensorTmp{}.get_lengths()[number<1>{}]); + + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-block-tensor from A-Block-tensor-tmp + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + auto a_scale_block_tensor = + make_static_distributed_tensor>( + MakeAScaleBlockTileDistribution()); + a_scale_block_tensor.get_thread_buffer() = a_scale_block_tensor_tmp.get_thread_buffer(); + + auto b_scale_block_tensor = + make_static_distributed_tensor>( + MakeBScaleBlockTileDistribution()); + b_scale_block_tensor.get_thread_buffer() = b_scale_block_tensor_tmp.get_thread_buffer(); + + // Construct B-warp-window + // Matrix B is shuffled in such a way that each lane calculates TargetCMPerLane consecutive + // elements of matrix C. See MakeBScaleBlockTileDistribution and MakeCBlockTile that shuffle + // B scale and C in the same way. + auto b_warp_window_tmp = [&] { + using Impl = typename WarpGemm::WarpGemmAttribute::Impl; + + constexpr index_t N3 = Impl::kCM1PerLane; + constexpr index_t N2 = TargetCMPerLane / N3; + constexpr index_t N1 = Impl::kCMLane; + constexpr index_t N0 = NPerBlock / (N1 * N2 * N3); + + const auto b_lds_unmerged = transform_tensor_view( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 2, 1, 3>{}, sequence<4>{})); + + const auto b_lds_merged = transform_tensor_view( + b_lds_unmerged, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tile_window( + b_lds_merged, + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + }(); + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + using AScaleWarpDstr = + remove_cvref_t; + using AScaleWarpTensor = + static_distributed_tensor, + AScaleWarpDstr>; + + using BScaleWarpDstr = + remove_cvref_t; + using BScaleWarpTensor = + static_distributed_tensor, + BScaleWarpDstr>; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto a_scale_warp_y_lengths = + to_sequence(AScaleWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_scale_warp_y_lengths = + to_sequence(BScaleWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_scale_warp_y_index_zeros = + uniform_sequence_gen_t{}; + constexpr auto b_scale_warp_y_index_zeros = + uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + auto b_warp_window = b_warp_window_tmp; + move_tile_window( + b_warp_window, + {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)}); + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_window); + + BScaleWarpTensor b_scale_warp_tensor; + + b_scale_warp_tensor.get_thread_buffer() = + b_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths)); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + AScaleWarpTensor a_scale_warp_tensor; + + a_scale_warp_tensor.get_thread_buffer() = + a_scale_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_scale_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}.template operator()<0, 0>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + int32_t(a_scale_warp_tensor.get_thread_buffer()[0]), + int32_t(b_scale_warp_tensor.get_thread_buffer()[0])); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK; + + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeAScaleWarpDstrEncoding() + { + using Impl = typename WarpGemm::WarpGemmAttribute::Impl; + + constexpr index_t AScaleMLane = Impl::kAMLane; + constexpr index_t ABScaleKLane = Impl::kABKLane; + constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity; + + return ck_tile::tile_distribution_encoding< + ck_tile::sequence<>, + ck_tile::tuple, + ck_tile::sequence>, + ck_tile::tuple>, + ck_tile::tuple>, + ck_tile::sequence<2>, + ck_tile::sequence<1>>{}; + } + + CK_TILE_DEVICE static constexpr auto MakeBScaleWarpDstrEncoding() + { + using Impl = typename WarpGemm::WarpGemmAttribute::Impl; + + constexpr index_t BScaleNLane = Impl::kBNLane; + constexpr index_t ABScaleKLane = Impl::kABKLane; + constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity; + + return ck_tile::tile_distribution_encoding< + ck_tile::sequence<>, + ck_tile::tuple, + ck_tile::sequence>, + ck_tile::tuple>, + ck_tile::tuple>, + ck_tile::sequence<2>, + ck_tile::sequence<1>>{}; + } + + template + CK_TILE_DEVICE static constexpr auto MakeAScaleBlockTileDistribution() + { + constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK; + + constexpr auto a_scale_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_scale_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_scale_block_outer_dstr_encoding, MakeAScaleWarpDstrEncoding()); + + return make_static_tile_distribution(a_scale_block_dstr_encode); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBScaleBlockTileDistribution() + { + constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK; + + using Impl = typename WarpGemm::WarpGemmAttribute::Impl; + + constexpr index_t ABScaleKLane = Impl::kABKLane; + constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity; + + constexpr auto b_scale_block_dstr_encode = ck_tile::tile_distribution_encoding< + ck_tile::sequence, + ck_tile::tuple, + ck_tile::sequence>, + ck_tile::tuple, ck_tile::sequence<2, 1, 1, 1>>, + ck_tile::tuple, ck_tile::sequence<1, 4, 2, 5>>, + ck_tile::sequence<1, 1, 2, 2>, + ck_tile::sequence<0, 3, 0, 2>>{}; + + return make_static_tile_distribution(b_scale_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + using Impl = typename WarpGemm::WarpGemmAttribute::Impl; + + constexpr auto c_block_dstr_encode = ck_tile::tile_distribution_encoding< + ck_tile::sequence<>, + ck_tile::tuple, + ck_tile::sequence>, + ck_tile::tuple, ck_tile::sequence<2, 1>>, + ck_tile::tuple, ck_tile::sequence<2, 2>>, + ck_tile::sequence<1, 2, 2, 2, 2>, + ck_tile::sequence<0, 0, 3, 4, 5>>{}; + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const AScaleBlockTensorTmp& a_scale_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp, + const BScaleBlockTensorTmp& b_scale_block_tensor_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, + a_block_tensor_tmp, + a_scale_block_tensor_tmp, + b_block_window_tmp, + b_scale_block_tensor_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..d97653f86b --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmMxARegBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 2f25ae9bf5..f393526de1 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -407,6 +407,12 @@ using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4, AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, @@ -427,6 +433,36 @@ using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; +template +using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 21a1dd5ba6..f79741ea96 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -446,6 +446,19 @@ struct WarpGemmAttributeMfmaTransposedCDistribution Impl{}(c_vec, b_vec, a_vec, bool_constant{}); } + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale, + bool_constant = {}) const + { + // swap A and B + Impl{}.template operator()( + c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant{}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { @@ -540,6 +553,19 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB Impl{}(c_vec, b_vec, a_vec, bool_constant{}); } + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale, + bool_constant = {}) const + { + // swap A and B + Impl{}.template operator()( + c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant{}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index c15e51af13..bc591ae740 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1599,6 +1599,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 static constexpr index_t kCM0PerLane = 1; static constexpr index_t kCM1PerLane = 4; + static constexpr index_t kScaleGranularity = 32; + // To get unity scale: 2^(kDefaultScale - 127) = 1.0 static constexpr index_t kDefaultScale = 0x7F7F7F7F; @@ -1683,15 +1685,15 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 }; template -struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4 { static constexpr WGAttrCtlEnum Ctrl = Ctrl_; using ADataType = AType_; using BDataType = BType_; using CDataType = float; - using AVecType = ext_vector_t; - using BVecType = ext_vector_t; + using AVecType = ext_vector_t::PackedSize>; + using BVecType = ext_vector_t::PackedSize>; using CVecType = ext_vector_t; static constexpr index_t kM = 32; @@ -1711,6 +1713,71 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base static constexpr index_t kCM0PerLane = 4; static constexpr index_t kCM1PerLane = 4; + static constexpr index_t kScaleGranularity = 32; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale, + bool_constant = {}) const + { +#if defined(__gfx950__) + auto dtype2conf = [](auto dtype) { + if constexpr(std::is_same_v) + return make_tuple(number<0>{}, int32x8_t{}); + else if constexpr(std::is_same_v) + return make_tuple(number<1>{}, int32x8_t{}); + else if constexpr(std::is_same_v) + return make_tuple(number<2>{}, pk_fp6x32_t{}); + // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<4>{}, int32x4_t{}); + else + static_assert(false, "Unsupported data type for mfma scale"); + }; + auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); }; + auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); }; + auto arg256 = [&](auto x) { + if constexpr(sizeof(x) == 16) + return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0}; + else if constexpr(sizeof(x) == 24) + return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0}; + else if constexpr(sizeof(x) == 32) + return x; + else + static_assert(false, "Unexpected vector size for mfma scale"); + }; + + auto arg_a = bit_cast(a_vec); + auto arg_b = bit_cast(b_vec); + constexpr int cbsz = decltype(dtype2code(ADataType{}))::value; + constexpr int blgp = decltype(dtype2code(BDataType{}))::value; + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_scale; +#endif + } + + // c_vec = a_vec * b_vec + template + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, + const int32_t& a_scale, + const BVecType& b_vec, + const int32_t& b_scale) const + { + CVecType c_vec{0.f}; + operator()(c_vec, a_vec, a_scale, b_vec, b_scale); + return c_vec; + } + // c_vec += a_vec * b_vec template CK_TILE_DEVICE void operator()(CVecType& c_vec, @@ -1718,67 +1785,31 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base const BVecType& b_vec, bool_constant = {}) const { - //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, - // opsel, scale_b) -#if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); -#else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; -#endif + operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx950__) - if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); - else if constexpr(std::is_same_v && std::is_same_v) - return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( - a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); -#else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; - return CVecType{0.f}; -#endif + return operator()<0, 0>(a_vec, 0, b_vec, 0); } }; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4; template using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4; // int8 template diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 6c25050c9c..081ff5150d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -130,6 +130,8 @@ template struct Dispatcher struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed; }; + template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; @@ -143,6 +145,13 @@ template<> struct Dispatcher struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed; }; +template struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed; }; + +template struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed; }; + template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; }; @@ -152,7 +161,6 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; - //WMMA cases template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_f8_f8; }; template struct Dispatcher { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 017391549a..df4818a7c5 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -53,7 +53,9 @@ set(REGRESSION_TESTS test_ck_tile_fmha_fwd_fp32 test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 - test_ck_tile_fmha_fwd_fp8 + test_ck_tile_fmha_fwd_fp8bf16 + test_ck_tile_fmha_fwd_mxfp8 + test_ck_tile_fmha_fwd_mxfp4 test_ck_tile_streamk_extended ) diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index bb5aad2a73..d296c40cc3 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -7,15 +7,17 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") set(TEST_NAME "test_ck_tile_fmha") function(add_gtest_fwd test_group) - if((GPU_TARGETS MATCHES "gfx90a" AND CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx9[45]|gfx12") - set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32") - elseif((GPU_TARGETS MATCHES "gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx11") - set(V_TYPES "fp16" "bf16" "fp32") + set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32" "mxfp8" "mxfp4") + if(GPU_TARGETS MATCHES "gfx908|gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + # fp8 instances are built for all gfx9, do not test on archs without hardware support + list(REMOVE_ITEM V_TYPES "fp8bf16") endif() set(CPP_TYPE_fp16 "FmhaFwdFp16") set(CPP_TYPE_bf16 "FmhaFwdBf16") set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16") set(CPP_TYPE_fp32 "FmhaFwdFp32") + set(CPP_TYPE_mxfp8 "FmhaFwdMxFp8") + set(CPP_TYPE_mxfp4 "FmhaFwdMxFp4") set(sources) if(TARGET ${FMHA_FWD_INSTANCES}) diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index c59ee7a67d..c2a90360d9 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -1,6 +1,11 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include +#include +#include +#include + #include "example/ck_tile/01_fmha/fmha_fwd.hpp" #include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp" @@ -42,6 +47,7 @@ struct TestConfigs static constexpr auto qscale_str = "n"; static constexpr bool def_lse = true; static constexpr bool def_is_v_rowmajor = true; + static constexpr auto init_method = "uf"; static int adjust_seqlen(int seqlen) { return seqlen; } }; @@ -57,11 +63,45 @@ struct TestConfigs static constexpr auto qscale_str = "pt"; static constexpr bool def_lse = false; static constexpr bool def_is_v_rowmajor = true; + static constexpr auto init_method = "3"; // When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests: // return ck_tile::integer_least_multiple(seqlen, 128); static int adjust_seqlen(int seqlen) { return seqlen; } }; +template <> +struct TestConfigs +{ + static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}}; + static constexpr auto SplitKVHDimValues = std::array, 0>{}; + static constexpr auto AppendKVHDimValues = std::array, 0>{}; + static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; + static constexpr auto IsVRowmajorValues = std::array{false}; + static constexpr auto qscale_str = "mx"; + static constexpr bool def_lse = true; + static constexpr bool def_is_v_rowmajor = false; + static constexpr auto init_method = "3"; + static int adjust_seqlen(int seqlen) { return seqlen; } +}; + +template <> +struct TestConfigs +{ + static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}}; + static constexpr auto SplitKVHDimValues = std::array, 0>{}; + static constexpr auto AppendKVHDimValues = std::array, 0>{}; + static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; + static constexpr auto IsVRowmajorValues = std::array{false}; + static constexpr auto qscale_str = "mx"; + static constexpr bool def_lse = true; + static constexpr bool def_is_v_rowmajor = false; + static constexpr auto init_method = "3"; + static int adjust_seqlen(int seqlen) + { + return seqlen < 0 ? seqlen : ck_tile::integer_least_multiple(seqlen, 2); + } +}; + template <> struct TestConfigs { @@ -81,6 +121,7 @@ struct TestConfigs static constexpr auto qscale_str = "n"; static constexpr bool def_lse = true; static constexpr bool def_is_v_rowmajor = true; + static constexpr auto init_method = "uf"; static int adjust_seqlen(int seqlen) { return seqlen; } }; @@ -92,8 +133,8 @@ static auto IsVRowmajorValues = ValuesIn(TestConfigs::IsVRowm constexpr static auto qscale_str = TestConfigs::qscale_str; constexpr bool def_lse = TestConfigs::def_lse; constexpr bool def_is_v_rowmajor = TestConfigs::def_is_v_rowmajor; +constexpr auto init_method = TestConfigs::init_method; int adjust_seqlen(int seqlen) { return TestConfigs::adjust_seqlen(seqlen); } -constexpr auto init_method = "uf"; // Random seed used for initializing input tensors. 0 for non-deterministic seed CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456) @@ -901,12 +942,6 @@ using PaddingParam = std::tuple; // mask_str -// Ensure headers for containers / algorithms used in padding param builder. -#include -#include -#include -#include - class PaddingCases : public TestWithParam { }; @@ -918,6 +953,12 @@ static std::vector BuildPaddingParams() { std::vector params; + if constexpr(ck_tile::is_any_of:: + value) + { + return params; + } + // mask variants to cover const std::vector mask_variants{"0", "t:50,64", "b:32,40"}; const std::vector mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets @@ -1106,15 +1147,10 @@ static std::vector BuildPaddingParams() static const std::vector kPaddingParams = BuildPaddingParams(); -INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams)); +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PaddingCases, ValuesIn(kPaddingParams)); TEST_P(PaddingCases, DataTypeConfig) { - if constexpr(std::is_same_v) - { - GTEST_SKIP() << "Skip for fp8"; - } - auto [mode, batch, nhead,