From 11e6485457d61b1d3da1b680d828e405c5031460 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" <210906412+assistant-librarian[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:55:50 -0800 Subject: [PATCH] [CK_TILE] Extend support of mixed-precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported type combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: this is consistent with other implementations and avoids handling the packed size explicitly. In general, the raw type should not be used because CK Tile takes care of the PackedSize internally, so using the raw type only adds unnecessary complexity to the implementation - handle microscaling by checking for the `e8m0` type for BQuant (the previous implementation was inconsistent); see the short e8m0 sketch below - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend the existing pipeline to support different B types - add support for scaling and casting either before writing to LDS or after reading from LDS (the user can select this in the `Problem`) - block gemm: - the mx pipeline now uses the BQuant block gemm - the BQuant block gemm can now load from LDS, apply the scale, and then call the universal block gemm operator. This adds new functionality and removes code duplication - warp gemm: - add a case to support 128-bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: some tests for bf16/fp4 already existed but were removed during a previous test refactoring. I added them back, together with other relevant tests for the new type combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to the REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
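Background for the e8m0 handling mentioned in the summary: `e8m0` is an exponent-only (power-of-two) scale type, and in grouped BQuant a single scale covers `QuantGroupSize::kK` consecutive K elements of a B column. A minimal host-side sketch of these semantics (the helper names below are illustrative, not part of the patch):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>

// e8m0 stores only a float32 exponent with bias 127: value = 2^(byte - 127)
// (0xFF, the NaN encoding, is ignored in this sketch).
inline float decode_e8m0(uint8_t e) { return std::ldexp(1.0f, static_cast<int>(e) - 127); }

// Dequantize one B element: the scale is shared by a group of kK consecutive
// values along K, matching how the reference kernel indexes q(k / kK, n).
template <int kK>
float dequant_b(float b_raw, const uint8_t* q_column, std::size_t k)
{
    return b_raw * decode_e8m0(q_column[k / kK]);
}
```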
- [ ] I have added inline documentation which enables the maintainers to understand the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered --- 🔁 Imported from [ROCm/composable_kernel#3689](https://github.com/ROCm/composable_kernel/pull/3689) 🧑‍💻 Originally authored by @EnricoDeg --------- Co-authored-by: Enrico Degregori Co-authored-by: systems-assistant[bot] Co-authored-by: Thomas Ning Co-authored-by: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../38_block_scale_gemm/CMakeLists.txt | 4 +- example/ck_tile/38_block_scale_gemm/README.md | 2 +- .../gemm_bquant_quantgrouped_mx_bf16bf16.cpp | 35 ++ .../gemm_bquant_quantgrouped_mx_bf16bf8.cpp | 34 ++ ...> gemm_bquant_quantgrouped_mx_bf16fp4.cpp} | 18 +- .../38_block_scale_gemm/gemm_quant.cpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 20 +- .../run_gemm_quant_example.inc | 77 +-- .../arch/amd_buffer_addressing_builtins.hpp | 6 + .../core/tensor/load_tile_transpose.hpp | 87 +-- .../core/tensor/static_distributed_tensor.hpp | 2 +- include/ck_tile/core/tensor/tensor_view.hpp | 7 +- include/ck_tile/core/tensor/tile_window.hpp | 8 +- .../ck_tile/host/reference/reference_gemm.hpp | 65 ++- include/ck_tile/ops/common/utils.hpp | 1 + .../unary_element_wise_operation.hpp | 298 +++++++++++ .../ops/epilogue/cshuffle_epilogue.hpp | 2 +- .../block/block_universal_gemm_as_bs_cr.hpp | 3 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 37 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 57 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 27 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 174 ++++-- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 21 +- include/ck_tile/ops/gemm_quant.hpp | 6 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 242 ++++++--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 49 +- ...emm_microscale_pipeline_ag_bg_cr_base.hpp} | 14 +- ...mm_microscale_pipeline_ag_bg_cr_policy.hpp | 296 +++++++++++ ...
gemm_microscale_pipeline_ag_bg_cr_v3.hpp} | 499 +++++++++++------- .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 140 ----- .../pipeline/gemm_quant_pipeline_problem.hpp | 23 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 41 ++ .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_bquant_1d_128.cpp | 6 +- ...emm_quant_bquant_microscale_ccr_1d_128.cpp | 41 ++ ...gemm_quant_bquant_microscale_ccr_1d_64.cpp | 45 ++ ...emm_quant_bquant_microscale_crr_1d_128.cpp | 42 ++ ...gemm_quant_bquant_microscale_crr_1d_64.cpp | 42 ++ ...emm_quant_bquant_microscale_rcr_1d_128.cpp | 51 ++ ...gemm_quant_bquant_microscale_rcr_1d_64.cpp | 51 ++ ...emm_quant_bquant_microscale_rrr_1d_128.cpp | 43 ++ ...gemm_quant_bquant_microscale_rrr_1d_64.cpp | 43 ++ .../test_gemm_quant_fixtures.hpp | 75 ++- 44 files changed, 2061 insertions(+), 683 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_bf16mxfp4.cpp => gemm_bquant_quantgrouped_mx_bf16fp4.cpp} (67%) rename include/ck_tile/ops/gemm_quant/pipeline/{gemm_mxfp4_pipeline_ag_bg_cr_base.hpp => gemm_microscale_pipeline_ag_bg_cr_base.hpp} (80%) create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp rename include/ck_tile/ops/gemm_quant/pipeline/{gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp => gemm_microscale_pipeline_ag_bg_cr_v3.hpp} (60%) delete mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index e7a218152d..97e719177f 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -20,7 +20,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_bquant_quantgrouped_mx_bf16fp4.cpp + gemm_bquant_quantgrouped_mx_bf16bf8.cpp + gemm_bquant_quantgrouped_mx_bf16bf16.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index eb36ae5800..accac6f083 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: 
Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp new file mode 100644 index 0000000000..e1a64c8656 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp new file mode 100644 index 0000000000..0eb2a0ce34 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +using GemmConfig = GemmConfigMixedPrecision; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp similarity index 67% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp index b8eb670135..1f48609a1f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp @@ -6,33 +6,33 @@ template using GemmConfig = GemmConfigQuantPrefill; -#define RUN_GEMM_EXAMPLE_PREC_TYPE \ - run_gemm_example_prec_type, \ - TypeConfig, \ - QuantGroupSize, \ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); static auto _ = []() { auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); + ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index cc4302a992..dc4d1ad814 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8, fp4") + " mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 9a51c786b6..db3f4c6e17 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType, std::conditional_t>; // Calculate thresholds @@ -278,6 +278,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill static constexpr bool TransposeC = true; }; +// Used for A=16bit and B=8bit. The warp tile has KPack=16 +// Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit) +// Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit) +struct GemmConfigMixedPrecision : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + template struct GemmConfigEightWarps : public GemmConfigABQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index c2954f3bf5..da14f85c2c 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -108,6 +108,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; + constexpr auto b_cast_policy = + std::is_same_v + ? ck_tile::CastPolicy::BeforeLDSWrite + : ck_tile::CastPolicy::AfterLDSRead; // row-col and tensor quants use the regular pipeline, A/B/AB quants use their own using PipelineProblem = std::conditional_t< @@ -150,7 +154,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ComputeDataType, GemmConfig::Scheduler, has_hot_loop_v, - tail_number_v>, + tail_number_v, + b_cast_policy>, ck_tile::GemmABQuantPipelineProblem, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + std::conditional_t, + ck_tile::MicroscaleGemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using ABQuantPipeline = std::conditional_t< eight_warps, @@ -257,11 +261,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? 
args.K / 2 - : args.K, - args.N, - args.stride_B, - is_row_major(BLayout{}))); + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -495,11 +495,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -531,11 +527,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -575,18 +568,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + range_min, range_max, fill_seed(gen)}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(*bq_tensor_ptr, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -850,18 +856,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - if constexpr(std::is_same_v) - ck_tile::reference_mxfp4gemm_quant( + if constexpr(std::is_same_v) + ck_tile::reference_mx_gemm_bquant( a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant) && + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index c165cacba2..c74068c03c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2865,6 +2865,12 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); } + else if constexpr(std::is_same_v, 
ck_tile::pk_fp4_t>) + { + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t; + auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_); + return bit_cast>(__builtin_amdgcn_ds_read_tr4_b64_v2i32(lds_ptr)); + } else { static_assert(false, "not implemented"); diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 2d71a9cfab..462a9cf4ab 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -50,60 +50,61 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix::valu template struct DefaultTranspose { - template - struct Quad16 + template + struct Quad { static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, "LaneGroupSize must be 64, 32, or 16"); - using InputEncoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<2>>; - using OutputEncoding = - tile_distribution_encoding, - tuple, sequence<4>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; + // The tile is defined by the LaneGroupSize, which gives the number of lanes in the M/N + // dimensions for the MMA instruction defined by the warp gemm. + // The LaneGroupSize is subdivided into groups of 16 (the finer granularity of MMA + // instructions); we call these major subtiles. Each major subtile is divided + // into minor subtiles, which group the lanes that exchange data during the transpose. + // Example: LaneGroupSize = 16, 16-bit type: + // - There is 1 group of 16 lanes (1 major subtile) + // - Each major subtile is divided into 4 minor subtiles of (4x4) -> 4 lanes transpose + // each minor subtile and each lane holds 4 elements + + // all load transpose instructions currently use 64 bits + static constexpr index_t InstructionBits = 64; + // Subtile major dimension is fixed + static constexpr index_t SubtileMajorDimension = 16; + // Number of major subtiles + static constexpr index_t NumSubtilesMajor = LaneGroupSize / 16; + // Number of elements loaded by each lane with a single instruction, and also the number + // of consecutive lanes in a subtile.
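+ // For reference, the same arithmetic for the other supported widths: an 8-bit type + // gives 64/8 = 8, i.e. 2 minor subtiles of (8x8); a packed 4-bit type (pk_fp4, + // PackedSize = 2, so 4 effective bits) gives 64/4 = 16, i.e. a single (16x16) + // minor subtile per major subtile.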
Subtile is squared (NLanes x NElementsPerLane) + static constexpr index_t SubtileMinorDimension = InstructionBits / NumBitType; + // Number of subtiles minor inside each subtile major + static constexpr index_t NumSubtilesMinor = 16 / SubtileMinorDimension; + + using InputEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<2>>; + + using OutputEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>; }; - template - struct Quad8 - { - static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16, - "LaneGroupSize must be 64, 32, or 16"); - using InputEncoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<2>>; - - using OutputEncoding = - tile_distribution_encoding, - tuple, sequence<8>>, - tuple>, - tuple>, - sequence<2>, - sequence<0>>; - }; + static constexpr index_t PackedSize = numeric_traits>::PackedSize; + static constexpr index_t NumBitsDataType = (sizeof(DataType) * 8) / PackedSize; // Select based on data size template - using QuadInputEncoding = std::conditional_t::InputEncoding, - typename Quad8::InputEncoding>; + using QuadInputEncoding = typename Quad::InputEncoding; template - using QuadOutputEncoding = std::conditional_t::OutputEncoding, - typename Quad8::OutputEncoding>; + using QuadOutputEncoding = typename Quad::OutputEncoding; // Always swap last two dimensions static constexpr auto transpose_dims = sequence<1, 0>{}; diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index bdd81dae07..787c17e1be 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -78,7 +78,7 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - thread_buffer + thread_buffer sliced_thread_data; static_ford>{}([&](auto idx) { diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 95d66b66ed..1d008b495b 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -287,8 +287,8 @@ struct tensor_view get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const { return buf_.template transpose_get( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); } @@ -303,7 +303,8 @@ struct tensor_view bool is_valid_element // flag ) const { - return buf_.template transpose_get(coord.get_offset(), linear_offset, is_valid_element); + return buf_.template transpose_get( + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element); } // X is vector of DataType. // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index ba7eeb1936..2f2fe12f42 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -736,7 +736,7 @@ struct tile_window_with_static_distribution .template get_transpose_vectorized_elements( bottom_tensor_thread_coord, offset); // write into distributed tensor - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -747,10 +747,12 @@ struct tile_window_with_static_distribution constexpr auto grouped_idx_ys = group_func(orig_idx_ys); constexpr index_t linear_distributed_index = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys) / + Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value + .template get_as()[j / Traits::PackedSize]; }); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index da6b074b98..b6d7fbf521 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -388,49 +388,56 @@ template -CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, - const HostTensor& q, - const HostTensor& b_k_n, - HostTensor& c_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) +CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const std::size_t M = a_m_k.get_length(0); const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0; - AccDataType pasual = 0; - for(std::size_t k = 0; k < (K / 2); k++) - { - using ComputeType = float; - auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; - ComputeType v_a_0, v_a_1; - ComputeType v_b_0, v_b_1; + AccDataType v_acc = 0; + using ComputeType = float; + ComputeType v_a; + ComputeType v_b; - v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); - v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); - - if constexpr(std::is_same_v) + auto load_b = [&](std::size_t k) -> AccDataType { + if constexpr(std::is_same_v) { - auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); - - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - - v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; - v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + const auto b_pack = type_convert(b_element_op(b_k_n(k, n))); + if constexpr(std::is_same_v) + { + return (n & 1) ? type_convert(b_pack.unpack(number<1>{})) + : type_convert(b_pack.unpack(number<0>{})); + } + else + { + return (k & 1) ? 
type_convert(b_pack.unpack(number<1>{})) + : type_convert(b_pack.unpack(number<0>{})); + } } + else + { + return ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + }; - pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; - v_acc += pasual; + for(std::size_t k = 0; k < K; k++) + { + const auto b_scale = type_convert(q(k / QuantGroupSize::kK, n)); + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + v_b = load_b(k) * b_scale; + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); }; diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 4a30e3af16..6c1287486f 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -24,6 +24,7 @@ template <> struct DataTypeTraits { static constexpr const char * nam template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "e8m0"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3f58eceb33..4ad699629c 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -359,6 +359,260 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) } #endif +CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + bf8_t elements[4]; + } input; + + union + { + bf16x2_t vec; + bf16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + fp8_t elements[4]; + } input; + + union + { + bf16x2_t vec; + bf16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = 
__builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + fp8_t elements[4]; + } input; + + union + { + fp16x2_t vec; + fp16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + constexpr index_t USE_BOTTOM = 0; + constexpr index_t USE_TOP = 1; + + auto convert_quartet = [&](index_t src_offset, index_t dst_offset) { + union + { + uint32_t packed; + bf8_t elements[4]; + } input; + + union + { + fp16x2_t vec; + fp16_t elements[2]; + } output; + + input.elements[0] = src[src_offset]; + input.elements[1] = src[src_offset + 1]; + input.elements[2] = src[src_offset + 2]; + input.elements[3] = src[src_offset + 3]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_BOTTOM); + y[dst_offset] = output.elements[0]; + y[dst_offset + 1] = output.elements[1]; + + output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_TOP); + y[dst_offset + 2] = output.elements[0]; + y[dst_offset + 3] = output.elements[1]; + }; + + convert_quartet(0, 0); + convert_quartet(4, 4); +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + + constexpr index_t USE_BYTE_0 = 0; + constexpr index_t USE_BYTE_1 = 1; + constexpr index_t USE_BYTE_2 = 2; + constexpr index_t USE_BYTE_3 = 3; + + cvt.pf4 = src; + bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_0); + bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_1); + bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_2); + bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_bf16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * 
i.value + 1] = yi[1]; + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE fp16x8_t fp4x4_to_fp16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + fp16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + + constexpr index_t USE_BYTE_0 = 0; + constexpr index_t USE_BYTE_1 = 1; + constexpr index_t USE_BYTE_2 = 2; + constexpr index_t USE_BYTE_3 = 3; + + cvt.pf4 = src; + fp16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_0); + fp16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_1); + fp16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_2); + fp16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_fp16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * i.value + 1] = yi[1]; + }); +#endif + return y; +} + struct PassThroughPack8 { static constexpr const char* name = "PassThroughPack8"; @@ -437,6 +691,50 @@ struct DequantPack8 y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); } + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const fp8x8_t& x, const float& z) const + { + y = fp8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const fp8x8_t& x, const float& z) const + { + y = fp8x8_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_fp16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const + { + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(x[i.value]) * z); + }); + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index b31f8ba02a..7ebfa412f7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -99,7 +99,7 @@ struct CShuffleEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t || std::is_same_v || - std::is_same_v, + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd51..7f34ae24bb 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c941..7cc14ecc39 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,8 +20,23 @@ struct GemmPipelineAgBgCrImplBase using ADataType = remove_cvref_t{}, AsDataType>>; using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = - std::conditional_t, ADataType, BInDataType>; + + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); + + using BDataType = std::conditional_t; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; @@ -226,6 +241,12 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&) const { + // with pk_int4_t load transpose the LDS type is always BDataType + using ADataTypeLDS = + std::conditional_t, + typename Problem::BDataType, + typename Problem::ADataType>; + auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); @@ -238,9 +259,8 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename ALdsLoadTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); + typename InputTileDistributionTraits::TransposedDstrEncode{}); else return ALdsLoadTileDistr{}; }(); @@ -313,10 +333,9 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLdsDataType = std::conditional_t; auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 987704e433..f9d82f8eb4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -10,6 +10,12 @@ namespace ck_tile { +enum struct CastPolicy +{ + BeforeLDSWrite, + AfterLDSRead, +}; + enum struct GemmPipelineScheduler { Default, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd..cb112a11a7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -80,6 +80,21 @@ struct UniversalGemmBasePolicy static constexpr bool is_b_load_tr = false; #endif + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + template + static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); + static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = 
number<2>{}; @@ -305,11 +320,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLayout = remove_cvref_t; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -589,15 +604,14 @@ struct UniversalGemmBasePolicy CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; - using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; using BLayout = remove_cvref_t{}, BsLayout>>; - using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; if constexpr(Problem::FixedVectorSize) { @@ -739,13 +753,13 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - using BDataType = remove_cvref_t; - constexpr index_t KPerBlock = std::is_same_v - ? Problem::BlockGemmShape::kK / 2 - : Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // If we cast before writing to LDS, the vectorsize is defined by the A type + // since the assumption is that A type is going to be the B LDS type + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; constexpr index_t VecLoadSize = - std::is_same_v - ? 4 + IsBCastPolicyBeforeLDSWrite + ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) : (Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< @@ -855,10 +869,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); @@ -900,7 +914,8 @@ struct UniversalGemmPipelineAgBgCrPolicy using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 0051242475..f3fa99304c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -185,16 +185,35 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, - AttrNumAccess>>; + AttrNumAccessA, + AttrNumAccessB>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 2, + AttrNumAccessA, + AttrNumAccessB>>; #else -template +template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, - AttrNumAccess>>; + AttrNumAccessA>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 4, + AttrNumAccessA, + AttrNumAccessB>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl +struct get_wgattr_num_access +{ + private: + static constexpr index_t getAccesses() + { + if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single) + { + return 1; + } + else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double) + { + return 2; + } + else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad) + { + return 4; + } + else + { + static_assert(false, "unsupported AttrNumAccess"); + return 0; + } + } + + public: + static constexpr auto value = getAccesses(); +}; + template + WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_> struct WarpGemmAttributeMfma { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; - static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccessA = AttrNumAccessA_; + static constexpr auto AttrNumAccessAV = get_wgattr_num_access::value; + static constexpr auto AttrNumAccessB = AttrNumAccessB_; + static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; + + static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -44,12 +78,13 @@ struct WarpGemmAttributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template + template static constexpr auto get_warp_dstr_encoding() { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(kKPerThread % AttrNumAccessV_ == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(AttrNumAccessV_ == 1) + { 
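+ // Single access: each lane's kKPerThread elements form one contiguous K run, so + // kABKPerLane stays a single Y dimension in the encoding below and no + // packed/interleaved split is needed in this branch.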
return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -57,18 +92,48 @@ struct WarpGemmAttributeMfma tuple>, sequence<2>, sequence<1>>{}; + } else - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; + { + // AttrNumAccess splits the kABKPerLane + // We can split them but still have them contiguous (packed) or have them interleaved. + // The reason to split the dimension but still have it packed is to match load transpose + // encoding when A and B use different AttrNumAccess (they have different types in LDS) + // Example + // A: 16bit, B: 8bit + // Load transpose B: lane0 -> K=0..7 (only 1 instruction) + // Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction + // In this way the data in register are consistent between A and B + if constexpr(UsePackNumAccess) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<1, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -121,14 +186,19 @@ struct WarpGemmAttributeMfma template + WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_> struct WarpGemmAttributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; - static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccessA = AttrNumAccessA_; + static constexpr auto AttrNumAccessAV = get_wgattr_num_access::value; + static constexpr auto AttrNumAccessB = AttrNumAccessB_; + static constexpr auto AttrNumAccessBV = get_wgattr_num_access::value; + + static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; @@ -151,14 +221,15 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(kKPerThread % AttrNumAccessV_ == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(AttrNumAccessV_ == 1) + { return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -166,21 +237,40 @@ struct WarpGemmAttributeMfmaIterateK tuple>, sequence<2>, sequence<1>>{}; + } else - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>{}; + { + if constexpr(UsePackNumAccess) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<1, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } } else 
if constexpr(kMNBlock == 1 && 1 < kNMBlock) { - static_assert(AttrNumAccessV == 1, + static_assert(AttrNumAccessV_ == 1, "Multiple access is not supported when using multi-block"); // each M/N blocks share the same data return tile_distribution_encoding< @@ -193,7 +283,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(1 < kMNBlock && kNMBlock == 1) { - static_assert(AttrNumAccessV == 1, + static_assert(AttrNumAccessV_ == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< @@ -245,10 +335,14 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index f9a988a923..21360874fb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -24,9 +24,10 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccessA = ESingle, + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA> struct Dispatcher; // clang-format off @@ -78,6 +79,10 @@ template<> struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; @@ -166,9 +171,10 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false, + WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single, + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA> using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // AType, BType, @@ -179,6 +185,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // TransposeC, SwizzleA, UseStructuredSparsity, - AttrNumAccess>::Type; + AttrNumAccessA, + AttrNumAccessB>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 91a9521c4f..c2fe66ea5d 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -24,9 +24,9 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" -#include 
"ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9d711c4862..3af7177365 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -101,20 +102,33 @@ struct BQuantBlockUniversalGemmAsBsCr // 2. bf8, bf8, fp32 -> f32 // 3. i4, fp8, (fp8/fp32) -> f32 // 4. i4, bf8, (fp8/fp32) -> f32 - static_assert((std::is_same_v || std::is_same_v) && - (std::is_same_v || std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v) && - std::is_same_v); + // 5. bf16, (bf16/bf8/fp8/fp4), e8m0 -> f32 + // 6. fp16, (fp16/fp8/bf8/fp4), e8m0 -> f32 + static_assert( + is_any_of::value && + is_any_of::value && + is_any_of::value && + is_any_of::value && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; static constexpr index_t KPack = WarpGemm::kKPerThread; static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + + template + using has_bcastpolicy_type = decltype(T::BCastPolicy); + + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(is_detected{}) + { + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + } + else + { + return false; + } + }(); }; public: @@ -127,9 +141,12 @@ struct BQuantBlockUniversalGemmAsBsCr using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading + // OverrideBDataType is only used when BCastPolicy is CastBeforeLDSWrite for microscale. 
+ // In that case we use ADataType using OverrideBDataType = std::conditional_t< - std::is_same_v && - std::is_same_v, + (std::is_same_v && + std::is_same_v) || + Traits::IsBCastPolicyBeforeLDSWrite, ADataType, BDataType>; @@ -176,57 +193,17 @@ struct BQuantBlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + // Use gemm universal block distribution encoding instead of duplicating it + using BlockGemmBase = BlockUniversalGemmAsBsCr; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; + return BlockGemmBase::MakeABlockDistributionEncode(); } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; + return BlockGemmBase::MakeBBlockDistributionEncode(); } private: @@ -235,20 +212,24 @@ struct BQuantBlockUniversalGemmAsBsCr { }; + using BlockGemmImplBase = typename BlockUniversalGemmAsBsCr:: + template BlockGemmImpl; + template - struct BlockGemmImpl + struct BlockGemmImpl : public BlockGemmImplBase { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BlockGemmImplBase::a_warp_tile_; + using BlockGemmImplBase::b_warp_tile_; + using BlockGemmImplBase::BLdsTileDistr; + // If we apply scale while reading from LDS, then we can use the operator() from + // BlockUniversalGemmAsBsCr + using BlockGemmImplBase::operator(); - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; + // static distributed tensor with LDS type + using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + BTypeTile b_warp_tile_lds_; + // Load from LDS (assumption is that the scale will be applied in the block gemm) template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + const BQRegBlockTile& bq_block_tensor, + bool_constant = {}, + bool_constant = {}) 
+ { + // Load tile from LDS + + // Do not use load_int4_tile here because it will have support to cast from fp4 to + // compute type, while here we want to only load from LDS and then apply the scale + // and cast later + if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + + if constexpr(BLoadTranspose) + { + b_warp_tile_lds_ = load_tile_transpose(b_block_window); + } + else + { + load_tile(b_warp_tile_lds_, b_block_window); + } + + // Apply scale and cast + using BDataTypeRaw = + std::conditional_t, pk_fp4_t::type, BDataType>; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size; + constexpr index_t thread_buffer_size = nelements / UnaryOpSize_; + const element_wise::DequantPack8 elementwise_op{}; + using SrcVectorRawType = ext_vector_t; + using DstVectorType = ext_vector_t; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // B scale register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return ((nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN) * + Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + // Get B scale from thread buffer + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); + + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // Thread buffers + using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + + BWarpThreadBuffer b_warp_thread_buffer; + BLDSThreadBuffer b_lds_thread_buffer; + + // Load thread buffer from tile (LDS type) + b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Apply scale to B thread buffer and cast + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op( + b_warp_thread_buffer.template get_as()(i), + b_lds_thread_buffer.template get_as()[i], + b_scale_f); + }); + + // Store B thread buffer to tile (MMA type) + b_warp_tile_.set_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), + b_warp_thread_buffer); + }); + }); + }); + } + // C += A * B template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + BQRegBlockTile bq_block_tile, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + block_gemm_impl_.LocalPrefetch( + a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr); + } + // C += A * B + // Apply scale after MMA template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + } + private: BlockGemmImpl block_gemm_impl_{}; }; 
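For reference, the scale-and-cast step that `LocalPrefetch` applies to each B thread buffer reduces to the following standalone sketch. It is a minimal illustration, assuming e8m0 encodes a power-of-two scale `2^(e - 127)` (consistent with the reference scaling code later in this patch); `decode_e8m0` and the float stand-ins are illustrative, not CK Tile API:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// e8m0 is a biased power-of-two exponent: scale = 2^(e - 127).
float decode_e8m0(std::uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

int main()
{
    // One quantization group of B values, already widened to float here;
    // in the kernel they are bf8/fp8/fp4 elements read back from LDS.
    std::vector<float> b_group = {0.5f, -1.25f, 2.0f, -0.75f};
    std::uint8_t scale_e8m0   = 129; // encodes 2^(129 - 127) = 4

    // Every element of the group is multiplied by the shared scale and then
    // (in the kernel) cast to the compute type consumed by the warp gemm.
    float s = decode_e8m0(scale_e8m0);
    for(float& b : b_group)
        b *= s;

    for(float b : b_group)
        std::printf("%g ", b); // prints: 2 -5 8 -3
    std::printf("\n");
    return 0;
}
```

In the kernel this multiply-and-cast is not a scalar loop: `DequantPack8` fuses it over packed thread-buffer vectors, as in the `LocalPrefetch` body above.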
diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 05e8aa62a9..62ac2115cc 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -787,20 +787,12 @@ struct QuantGemmKernel } else { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -814,16 +806,10 @@ struct QuantGemmKernel } else if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -848,17 +834,10 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp similarity index 80% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp index facec252a3..06ca9854b9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp @@ -10,7 +10,7 @@ namespace ck_tile { template -struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +struct GemmMicroscalePipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { using Base = GemmPipelineAgBgCrImplBase; using ADataType = typename Base::ADataType; @@ -42,10 +42,14 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..a026694769 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,296 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmMicroscalePipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQLayout = remove_cvref_t; + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + // Support both RowMajor and ColumnMajor layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + return GetABQGlobalVectorLoadSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + // If we apply scale before writing to LDS, we need a tile distribution for + // BQuant consistent with global memory reading of matrix B, while + // if we apply scale after reading from LDS, we need a tile distribution for + // BQuant consistent with the MMA instructions layout + if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead) + { + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; + + using TileEncodingPattern = + tile_distribution_encoding_pattern_bq; + + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize; + + constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; + + // For each BQ layout we need different encodings whether B has the same layout or not + // TODO: generalize encodings for different BQuantGroupSize granularity + if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + { + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = KScale; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t N1 = NPerBlock / b_vec; + constexpr index_t N2 = b_vec; + + constexpr index_t KRepeatInWave = warp_size / N1; + constexpr index_t KRepeatAcrossWave = num_warps / KScale; + + constexpr index_t K2 = num_warps / KRepeatAcrossWave; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 2>>, + tuple, sequence<1, 1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + } + else + { + if constexpr(std::is_same_v) + { + constexpr index_t NScale = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t N0 = NScale / b_vec; + constexpr index_t N1 = b_vec; + + constexpr index_t KLanes = warp_size / N0; + constexpr index_t KVec = KPerBlock / KLanes / num_warps; + constexpr index_t KRepeat = KPerBlock / KScale / KVec; + + constexpr index_t KRepeatInWave = KRepeat > KLanes ? KLanes : 1; + constexpr index_t KRepeatAcrossWave = KRepeat > KLanes ? KRepeat / KLanes : 1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 0, 2>>, + tuple, sequence<1, 1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t KRepeatInWave = Problem::BQuantGroupSize::kK / b_vec; + constexpr index_t K1 = KScale; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / (KRepeatInWave * K1); + + // Number of contiguous elements in N dimension when reading B matrix + // becomes the vector size of BQ + constexpr index_t N2 = NPerBlock / (BlockSize / (KPerBlock / b_vec)); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 0, 1, 0>>, + tuple, sequence<1, 1, 1, 2>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + } + } + } + + // Return AttrNumAccess for a given warp tile (defined by ThreadElements) and data type + template + static constexpr auto GetAttrNumAccess(bool_constant, number) + { + constexpr index_t PackedSize = numeric_traits>::PackedSize; + constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(DataType) * PackedSize; + + return !UseLoadTranspose ? WGAttrNumAccessEnum::Single + : vector_size == ThreadElements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == ThreadElements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == ThreadElements ? 
WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using ComputeDataType = typename Problem::ComputeDataType; + using LDSADataType = typename Problem::ADataType; + using LDSBDataType = std::conditional_t; + + static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of QuantGroupSize!"); + + constexpr auto thread_elements = + number{}; + + constexpr auto is_a_load_tr_v = bool_constant>{}; + constexpr auto is_b_load_tr_v = bool_constant>{}; + constexpr auto is_any_load_tr = is_a_load_tr_v || is_b_load_tr_v; + + constexpr auto wg_attr_num_access_compute = + GetAttrNumAccess(is_any_load_tr, thread_elements); + constexpr auto wg_attr_num_accessA = + std::is_same_v + ? wg_attr_num_access_compute + : GetAttrNumAccess(is_a_load_tr_v, thread_elements); + constexpr auto wg_attr_num_accessB = + std::is_same_v + ? wg_attr_num_access_compute + : GetAttrNumAccess(is_b_load_tr_v, thread_elements); + + using WarpGemm = WarpGemmDispatcher; + static_assert(is_any_of::value); + static_assert(std::is_same_v); + + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< + typename Problem::ADataType, + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + + return BQuantBlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp similarity index 60% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp index 7c448599ed..5a03057c64 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -18,15 +18,21 @@ namespace ck_tile { // B Tile Window: global memory // C Distributed tensor: register -template -struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +template +struct MicroscaleGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + using PipelineImplBase = GemmMicroscalePipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + + using BDqDataType = remove_cvref_t; + + static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + + using BLDSType = std::conditional_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BDqDataType = remove_cvref_t; using BQDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -40,12 +46,16 @@ struct 
MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; + static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; + ck_tile::numeric_traits>::PackedSize; static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BLDSPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -82,6 +92,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +178,11 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; + static constexpr bool is_b_row_major = + std::is_same_v; + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; @@ -207,7 +225,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void ScaleTile(const TileType& block_tile, + CastTileType& block_tile_cast, + const ScaleTileType& scale_tile) + { + if constexpr(IsCastBeforeLDS) + { + constexpr auto b_block = TileType::get_distributed_spans(); + + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 + // on gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(false, "unsupported compute type"); + } + }; + + constexpr index_t BQuantGroupSizeIdx0 = + std::is_same_v + ? BQuantGroupSize::kN + : BQuantGroupSize::kK; + constexpr index_t BQuantGroupSizeIdx1 = + std::is_same_v + ? BQuantGroupSize::kK + : BQuantGroupSize::kN; + + // The input indices are with respect to B block tile. 
If B and Bq have different + // layouts, the indices must be swapped + auto make_bq_index = [](auto idx0, auto idx1) { + if constexpr(std::is_same_v) + { + return make_tuple( + tile_distributed_index{}, + tile_distributed_index{}); + } + else + { + return make_tuple( + tile_distributed_index{}, + tile_distributed_index{}); + } + }; + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + if constexpr(std::is_same_v) + { + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + + constexpr auto i_j_idx = make_tuple(idx0, idx1); + auto b_pack = block_tile[i_j_idx]; + + constexpr auto i_j_idx_scale_lo = make_bq_index(idx0, idx1_lo); + constexpr auto i_j_idx_scale_hi = make_bq_index(idx0, idx1_hi); + + // If the scale is the same for packed values, use pk cvt scale + // instructions, otherwise scale and cast element by element + if constexpr(i_j_idx_scale_lo[I0{}].impl_.at(0) == + i_j_idx_scale_hi[I0{}].impl_.at(0) && + i_j_idx_scale_lo[I1{}].impl_.at(0) == + i_j_idx_scale_hi[I1{}].impl_.at(0)) + { + float scale = float(scale_tile[i_j_idx_scale_lo]); + auto cvt = pk_mxfp4_to_compute_v2(b_pack, scale); + + block_tile_cast(i_j_idx_lo) = cvt.x; + block_tile_cast(i_j_idx_hi) = cvt.y; + } + else + { + float scale_lo = float(scale_tile[i_j_idx_scale_lo]); + auto b_f4_lo = + type_convert(b_pack.unpack(number<0>{})); + block_tile_cast(i_j_idx_lo) = type_convert( + type_convert(b_f4_lo) * scale_lo); + + float scale_hi = float(scale_tile[i_j_idx_scale_hi]); + auto b_f4_hi = + type_convert(b_pack.unpack(number<1>{})); + block_tile_cast(i_j_idx_hi) = type_convert( + type_convert(b_f4_hi) * scale_hi); + } + } + } + else + { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_bq_index(idx0, idx1); + float scale = float(scale_tile[i_j_idx_scale]); + + auto b_pack = block_tile[i_j_idx]; + block_tile_cast(i_j_idx) = + type_convert(type_convert(b_pack) * scale); + } + }); + }); + } + } + + template + CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const ElementwiseFunc& element_func) const + { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, block_tile); + Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill(lds_window, block_tile, element_func); + } + } + + template + CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const TileTypeCast& block_tile_cast, + const ElementwiseFunc& element_func) const + { + // Fill LDS and apply the scale if IsCastBeforeLDS + auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { + if constexpr(IsCastBeforeLDS) + { + return b_block_tile_cast; + } + else + { + return b_block_tile_orig; + } + }; + + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast)); + Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill( + lds_window, 
get_b_block_tile(block_tile, block_tile_cast), element_func); + } + } + + template + CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm, + const AWindowType& a_lds_window, + const BWindowType& b_lds_window, + const QTileType& q_block_tile) const + { + // Load from LDS + // It can apply the scale and cast if we scale after reading from LDS + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch( + a_lds_window, b_lds_window, is_a_load_tr_v, is_b_load_tr_v); + } + else + { + block_gemm.LocalPrefetch( + a_lds_window, b_lds_window, q_block_tile, is_a_load_tr_v, is_b_load_tr_v); + } + } + template > && std::is_same_v; constexpr bool is_bq_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); - static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + static_assert(is_bq_col_major + ? (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}]), "Bq block window has incorrect lengths for defined BqLayout!"); static_assert(is_a_col_major @@ -347,13 +557,12 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3()); auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + // This defines the scaled and casted block tile for B matrix. + // Effectively, it is used only if we scale and cast before writing to LDS. + auto bdq_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + // Block GEMM auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using ABlockTile = @@ -402,114 +610,61 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(BBlockTileDistr{})); ABlockTile a_block_tile; - BBlockTile b_fp4_block_tile; + BBlockTile b_block_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex; constexpr ADramTileWindowStep a_dram_tile_window_step = is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2); + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK; + constexpr BQDramTileWindowStep b_scale_dram_tile_window_step = + std::is_same_v + ? 
make_array(0, KPerBlock / BQuantGroupSize::kK) + : make_array(KPerBlock / BQuantGroupSize::kK, 0); // ----------------------------------------------------------------------------------------- // Gemm pipeline start - // prefetch - // global read 0 - // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + // prefetch stages + + // Vmem -> Vgpr 0 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // BDataType - auto b_block_tile = make_static_distributed_tensor( - Policy::template MakeBRegTileDistribution()); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr 0 (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - constexpr auto idx1_js = tile_distributed_index<0>{}; - constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); - - // initialize C + // initialize C tile to zero tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS 0 + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr 1 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, 
b_scale_dram_tile_window_step}); - - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + // If we scale and cast before writing to LDS, + // we need to read another tile of Q matrix from Vmem, then scale and cast tile + if constexpr(IsCastBeforeLDS) + { + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + } + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // LDS -> Vgpr 0 + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); __builtin_amdgcn_sched_barrier(0); @@ -521,72 +676,34 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); - - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = - type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = 
type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); + // Consume tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); i += 1; - // b_block_stride +=1; } while(i < (num_loop - 1)); } - // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { @@ -596,35 +713,31 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS last tile + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr last tile + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + + // Consume last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); } @@ -653,9 +766,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + identity{}, b_dram_block_window_tmp, - [](const BDqDataType& b) { return b; }, + identity{}, bq_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp deleted file mode 100644 index 6cf9e22f41..0000000000 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
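The `ScaleTile` fast path above hinges on both fp4 halves of a `pk_fp4_t` sharing one e8m0 scale, so a single packed convert (`V_CVT_SCALEF32_PK_BF16_FP4` / `V_CVT_SCALEF32_PK_FP16_FP4` on gfx950) yields both outputs. A scalar emulation of that dequantization, with an illustrative e2m1 lookup table and nibble order that are not the CK Tile definitions:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Magnitudes representable by an e2m1 nibble; bit 3 carries the sign.
static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

float decode_fp4(std::uint8_t nibble)
{
    float mag = kE2M1[nibble & 0x7];
    return (nibble & 0x8) ? -mag : mag;
}

// e8m0 is a biased power-of-two exponent: scale = 2^(e - 127).
float decode_e8m0(std::uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

int main()
{
    std::uint8_t packed = 0x9A; // low nibble 0xA -> -1.0, high nibble 0x9 -> -0.5
    std::uint8_t scale  = 130;  // encodes 2^(130 - 127) = 8

    // Both halves share one scale, so on gfx950 a single packed convert
    // produces both outputs; two scalar ops emulate it here.
    float s  = decode_e8m0(scale);
    float lo = decode_fp4(packed & 0xF) * s;        // -8
    float hi = decode_fp4((packed >> 4) & 0xF) * s; // -4

    std::printf("lo=%g hi=%g\n", lo, hi);
    return 0;
}
```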
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "gemm_group_quant_utils.hpp" - -namespace ck_tile { - -struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy -{ - using Base = UniversalGemmPipelineAgBgCrPolicy; - using Base::I0; - using Base::I1; - using Base::I2; - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() - { - using BQLayout = remove_cvref_t; - using BQDataType = remove_cvref_t; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution() - { - using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - // Tile: KPerBlock X NPerBlock - if constexpr(std::is_same_v) - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - // Tile: NPerBlock X KPerBlock - else - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() - { - // using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - constexpr index_t warp_size = get_warp_size(); - constexpr index_t num_warps = BlockSize / get_warp_size(); - constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); - constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; - constexpr index_t K0 = KPerBlock / b_vec; - constexpr index_t K1 = K0 / KScale; - constexpr index_t K3 = K0 / K1; - constexpr index_t K2 = 1; - - constexpr index_t N0 = num_warps / NumWaveGroups; - constexpr index_t N1 = warp_size / K0; - constexpr index_t N2 = NPerBlock / (N0 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 0>>, - tuple, sequence<1, 0, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - - static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0, - "KPerWarpGemm must be a multiple of QuantGroupSize!"); - - using WarpGemm = WarpGemmDispatcher; - static_assert(std::is_same_v || - std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v); - - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< - typename Problem::ADataType, - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>, - typename Problem::CDataType, - BlockWarps, - WarpGemm>; - - return BlockUniversalGemmAsBsCr{}; - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 9b02585e69..fdaebe8010 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -24,7 +24,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase< ADataType_, @@ -82,6 +83,20 @@ struct GemmQuantPipelineProblemBase static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; + // gfx950 supports load with transpose for 4bit types, so we can transpose + // pk_fp4_t from LDS in registers. But without this instruction, + // the transpose is done in register between Vmem read and LDS write and + // the implementation does not support 4 bit types +#ifdef __gfx950__ + static constexpr auto BCastPolicy = BCastPolicy_; +#else + static constexpr auto BCastPolicy = + std::is_same_v && + std::is_same_v + ? 
CastPolicy::BeforeLDSWrite + : BCastPolicy_; +#endif + static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0); static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0); static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0); @@ -155,7 +170,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase; + TailNum_, + BCastPolicy_>; template , + std::is_same_v, ADataType_, std::conditional_t>; // Calculate thresholds diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4..0e6e40b788 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp new file mode 100644 index 0000000000..94572a80dc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_128.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // CCR BQ: C + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp new file mode 100644 index 0000000000..c6d1f0c341 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_ccr_1d_64.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
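The `BCastPolicy` fallback in `GemmQuantPipelineProblemBase` above amounts to a small constexpr selection: honor the requested policy when the hardware can transpose 4-bit types on the LDS read, otherwise force packed fp4 to be dequantized before the LDS write. A simplified sketch with stand-in types, not the CK Tile definitions:

```cpp
#include <type_traits>

enum class CastPolicy { BeforeLDSWrite, AfterLDSRead };

struct pk_fp4_t {}; // stand-in for ck_tile::pk_fp4_t
struct bf16_t {};   // stand-in for ck_tile::bf16_t

template <typename BDataType, bool HasFp4LdsTranspose>
constexpr CastPolicy select_b_cast_policy(CastPolicy requested)
{
    // Without a 4-bit LDS load-transpose instruction (i.e. not gfx950),
    // packed fp4 must be scaled and cast before it is written to LDS.
    if constexpr(!HasFp4LdsTranspose && std::is_same_v<BDataType, pk_fp4_t>)
        return CastPolicy::BeforeLDSWrite;
    else
        return requested;
}

// pk_fp4 B data without the instruction: the request is overridden.
static_assert(select_b_cast_policy<pk_fp4_t, false>(CastPolicy::AfterLDSRead) ==
              CastPolicy::BeforeLDSWrite);
// Other B types keep the requested policy.
static_assert(select_b_cast_policy<bf16_t, false>(CastPolicy::AfterLDSRead) ==
              CastPolicy::AfterLDSRead);
```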
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // CCR BQ: C + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp new file mode 100644 index 0000000000..e8744eb35a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_128.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // CRR BQ: C + std::tuple, + // CRR BQ: R + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp new file mode 100644 index 0000000000..dbc1ae7f2a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_crr_1d_64.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // CRR BQ: C + std::tuple, + // CRR BQ: R + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp new file mode 100644 index 0000000000..7637b8a12a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_128.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // RCR BQ: C + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize128>, + // RCR BQ: R + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>, + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128> +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, 
BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp new file mode 100644 index 0000000000..aa960ca16e --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rcr_1d_64.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + // RCR BQ: C + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP16, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, FP8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, FP8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, BF8, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP16, PkFP4, E8M0, FP16, BQuantGrouped, GemmConfigMx, GroupSize64>, + // RCR BQ: R + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>, + std::tuple< RowMajor, ColumnMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64> +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp new file mode 100644 index 0000000000..f181b432d4 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp
new file mode 100644
index 0000000000..f181b432d4
--- /dev/null
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_128.cpp
@@ -0,0 +1,43 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include <gtest/gtest.h>
+#include <tuple>
+
+#include "test_gemm_quant_fixtures.hpp"
+
+// Type aliases for readability
+using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
+using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
+using BF8 = ck_tile::bf8_t;
+using BF16 = ck_tile::bf16_t;
+using PkFP4 = ck_tile::pk_fp4_t;
+using E8M0 = ck_tile::e8m0_t;
+using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
+using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+
+// Type combinations for BQuant tests - 1D GroupSize 128
+// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, BQDataType, CDataType, QuantMode, GemmConfig, GroupSize>
+// clang-format off
+using BQuant1D128Types = ::testing::Types<
+    // RRR BQ: C
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize128>,
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>,
+    // RRR BQ: R
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize128>
+>;
+// clang-format on
+
+// Test suite for BQuant 1D 128
+TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types);
+
+// BQuant tests
+TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
+{
+    this->run_test_with_validation(1024, 1024, 1024);
+}
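All of these combinations use an E8M0 scale type. e8m0 stores only a biased float32 exponent (bias 127), so every scale decodes to a power of two; this is why the fixture below generates scales with `std::exp2` of small integers. A minimal decode sketch, illustrative only and not CK Tile's implementation (NaN encoding handling omitted):

    #include <cmath>
    #include <cstdint>

    // Illustrative e8m0 decode: value = 2^(raw - 127). The real ck_tile::e8m0_t
    // conversions live in the library; this only shows the numeric meaning.
    inline float decode_e8m0(std::uint8_t raw)
    {
        return std::exp2(static_cast<int>(raw) - 127);
    }

    // Example: raw = 127 -> 1.0f, raw = 129 -> 4.0f, raw = 125 -> 0.25f.
    // A dequantized B element is then b_fp = decode(b_quant) * decode_e8m0(scale).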
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp
new file mode 100644
index 0000000000..a02136b7db
--- /dev/null
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_microscale_rrr_1d_64.cpp
@@ -0,0 +1,43 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include <gtest/gtest.h>
+#include <tuple>
+
+#include "test_gemm_quant_fixtures.hpp"
+
+// Type aliases for readability
+using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
+using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
+using BF8 = ck_tile::bf8_t;
+using BF16 = ck_tile::bf16_t;
+using PkFP4 = ck_tile::pk_fp4_t;
+using E8M0 = ck_tile::e8m0_t;
+using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
+using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
+
+// Type combinations for BQuant tests - 1D GroupSize 64
+// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, BQDataType, CDataType, QuantMode, GemmConfig, GroupSize>
+// clang-format off
+using BQuant1D64Types = ::testing::Types<
+    // RRR BQ: C
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMxFP4, GroupSize64>,
+    std::tuple< RowMajor, RowMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
+    // RRR BQ: R
+    std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
+>;
+// clang-format on
+
+// Test suite for BQuant 1D 64
+TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types);
+
+// BQuant tests
+TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
+{
+    this->run_test_with_validation(1024, 1024, 1024);
+}
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 11fa6e038a..5a26034182 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -102,13 +102,24 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
 };
 
-struct GemmConfigMxFp4 : public GemmConfigBase
+struct GemmConfigMx : public GemmConfigBase
 {
     static constexpr ck_tile::index_t M_Tile = 128;
     static constexpr ck_tile::index_t N_Tile = 128;
     static constexpr ck_tile::index_t K_Tile = 128;
 };
 
+// This configuration uses K_Warp_Tile = 64 on CDNA so that, on gfx950, the LDS
+// transpose load can be used on matrix B (FP4): the instruction requires each
+// lane to load 16 4-bit elements
+struct GemmConfigMxFP4 : public GemmConfigBase
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile();
+};
+
 struct GemmConfigPreshuffleQuant : public GemmConfigBase
 {
     static constexpr bool APreshuffleQuant = true;
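To make the K_Warp_Tile comment in GemmConfigMxFP4 above concrete: 16 packed-FP4 elements per lane is a 64-bit access per lane. A compile-time sanity check of that arithmetic (a sketch only; the pairing with a specific DS transpose-load instruction is an assumption based on the comment, not taken from this patch):

    #include <cstddef>

    // 16 elements x 4 bits = 64 bits = 8 bytes moved per lane by the B (FP4)
    // LDS transpose load described in the GemmConfigMxFP4 comment.
    constexpr std::size_t fp4_bits       = 4;
    constexpr std::size_t elems_per_lane = 16;
    constexpr std::size_t bits_per_lane  = fp4_bits * elems_per_lane;

    static_assert(bits_per_lane == 64, "transpose load moves 64 bits per lane");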
@@ -666,8 +677,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase
-        const ck_tile::index_t stride_B =
-            std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? (K / 2) : K;
+        const ck_tile::index_t stride_B = K;
         const ck_tile::index_t stride_C = N;
 
         // BQuant uses block/grouped quantization for B matrix
@@ -678,24 +688,36 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase
         ck_tile::HostTensor<ADataType> a_m_k(
             ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
-        ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
-            std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? K / 2 : K,
-            N,
-            stride_B,
-            this->is_row_major(BLayout{})));
+        ck_tile::HostTensor<BDataType> b_k_n(
+            ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
         ck_tile::HostTensor<BQDataType> bq_bqk_bqn(
             ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
 
         // Initialize data with random values
         ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
             ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
-            ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f}(bq_bqk_bqn);
         }
         else
         {
             ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
+        }
+
+        if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
+        {
+            auto gen_scales = [&](auto& scales, float range_min, float range_max) {
+                // e8m0_t is essentially a biased float32 exponent
+                ck_tile::HostTensor<float> pow2(scales.get_lengths());
+                ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max}(pow2);
+                scales.ForEach([&](auto& self, const auto& i) {
+                    self(i) = static_cast<BQDataType>(std::exp2(pow2(i)));
+                });
+            };
+            gen_scales(bq_bqk_bqn, -2, 2);
+        }
+        else
+        {
             ck_tile::FillUniformDistribution<BQDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
         }
@@ -780,14 +802,15 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase
-        if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
-            ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
+        if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
+            ck_tile::reference_mx_gemm_bquant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
         else
             ck_tile::reference_gemm_quant
+            ? ck_tile::CastPolicy::BeforeLDSWrite
+            : ck_tile::CastPolicy::AfterLDSRead;
         using PipelineProblem = ck_tile::GemmBQuantPipelineProblem;
+                                            tail_number_v,
+                                            b_cast_policy_v>;
         using GemmPipeline = std::conditional_t<
             PreshuffleB == false,
-            std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
-                               ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
+            std::conditional_t<std::is_same_v<BQDataType, ck_tile::e8m0_t>,
+                               ck_tile::MicroscaleGemmPipelineAgBgCrCompV3<PipelineProblem>,
                                ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
             ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
         using GemmEpilogue = ck_tile::CShuffleEpilogue,
+            std::conditional_t, ADataType, BDataType>,
             ck_tile::tuple<>,