From 3a18b2f29e36478a6be871a16cffd4576d98ac23 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 30 Jan 2026 12:22:40 +0000 Subject: [PATCH] Merge commit '6a6177a246d6c81932fbb1061ad6a62e90b788a1' into develop --- .../gemm_abquant_quantgrouped.cpp | 30 +++++ .../38_block_scale_gemm/gemm_quant.cpp | 2 +- .../run_gemm_quant_example.inc | 124 +++++++++++------- include/ck_tile/core.hpp | 1 + .../core/arch/amd_buffer_addressing.hpp | 3 +- .../arch/amd_buffer_addressing_builtins.hpp | 2 +- include/ck_tile/core/numeric/pk_fp4.hpp | 88 ++++++++++++- include/ck_tile/core/numeric/pk_int4.hpp | 11 ++ include/ck_tile/core/numeric/vector_type.hpp | 1 + .../core/utility/mixed_prec_compute_type.hpp | 54 ++++++++ include/ck_tile/core/utility/type_traits.hpp | 17 +++ .../ck_tile/host/reference/reference_gemm.hpp | 82 ++++++------ .../ops/common/load_interleaved_pk_type.hpp | 19 ++- .../unary_element_wise_operation.hpp | 23 ++++ ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 20 ++- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 15 ++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 19 ++- .../gemm_abquant_pipeline_ag_bg_cr_policy.hpp | 5 +- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 36 ++--- .../pipeline/gemm_quant_pipeline_problem.hpp | 30 +++-- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 9 +- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 51 ++++--- test/ck_tile/gemm_block_scale/CMakeLists.txt | 16 +++ .../test_gemm_quant_abquant_a4w4_base.cpp | 44 +++++++ .../test_gemm_quant_abquant_a4w4_padding.cpp | 65 +++++++++ ...est_gemm_quant_abquant_a4w4_preshuffle.cpp | 44 +++++++ .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_fixtures.hpp | 4 +- 28 files changed, 642 insertions(+), 175 deletions(-) create mode 100644 include/ck_tile/core/utility/mixed_prec_compute_type.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index a7cb88079b..e4e0503b5a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -164,5 +164,35 @@ static auto _ = []() { BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings( + {"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; return 0; }(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 1fbe4d7b47..cc4302a992 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8") + "or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 665c7828ad..540d5725dd 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -9,6 +9,7 @@ #include #include #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/ops/common/utils.hpp" @@ -35,10 +36,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str static_assert(std::is_same_v); constexpr bool transpose_c = GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped; - using ComputeDataType = std::conditional_t; + + // Use automatically determined compute type from + using ComputeDataType = void; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,7 +80,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, ck_tile::BaseGemmPipelineAgBgCrMem, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::ABQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>>; const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -182,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - using GemmEpilogue = ck_tile::CShuffleEpilogue, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - 1, - false, - 1, - TiledPermuteN>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -557,8 +558,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { if constexpr(std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -594,18 +594,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( @@ -723,12 +731,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -804,12 +811,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); } else { @@ -984,10 +990,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if(arg_parser.get_int("v") == 1) { + std::cout << "Performing CPU verification..." << std::endl; + ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); + // Track start time for reference operation + auto start_reference_tick = std::chrono::high_resolution_clock::now(); if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( @@ -1061,6 +1074,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + // "Stop" our timer + auto verification_finished_tick = std::chrono::high_resolution_clock::now(); + if(!pass) { std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) @@ -1068,6 +1084,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, << std::endl; } std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + // Calculate and display reference timing + using DurationType = std::chrono::duration; + double reference_sec = std::chrono::duration_cast(verification_finished_tick - + start_reference_tick) + .count(); + double verification_sec = std::chrono::duration_cast( + verification_finished_tick - start_verification_tick) + .count(); + float reference_msec = static_cast(reference_sec * 1e3); + float verification_msec = static_cast(verification_sec * 1e3); + + std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took " + << reference_msec << "ms, verification took " << verification_msec << "ms." + << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -1098,6 +1129,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) } if constexpr(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index f3596df9bd..438e44f5f1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -91,6 +91,7 @@ #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/mixed_prec_compute_type.hpp" #include "ck_tile/core/utility/persistent_async_input_scheduler.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/print.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7af2f558ad..8f9dd30bda 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 9f9770df1b..42886b8ced 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index cc23ce71a8..d74db6b336 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -6,6 +6,7 @@ #include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" #if defined(__gfx950__) @@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +#else +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +#endif + // Helpers: constexpr-safe access to elements of ext_vector_type(2) // Some compilers don't allow operator[] in constant expressions for vector types. // We use bit_cast to a trivially copyable representation to extract lanes. @@ -98,6 +105,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } @@ -105,6 +114,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); } + CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); } template CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const @@ -145,6 +156,49 @@ struct pk_float4_e2m1_t bit_cast(static_cast(0xC400)), // -4 bit_cast(static_cast(0xC600)) // -6 }; + +#if CK_TILE_USE_OCP_FP8 + // FP8 EM4E3 (OCP) representation + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x30)), // 0.5 + fp8_t(static_cast(0x38)), // 1 + fp8_t(static_cast(0x3C)), // 1.5 + fp8_t(static_cast(0x40)), // 2 + fp8_t(static_cast(0x44)), // 3 + fp8_t(static_cast(0x48)), // 4 + fp8_t(static_cast(0x4C)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB0)), // -0.5 + fp8_t(static_cast(0xB8)), // -1 + fp8_t(static_cast(0xBC)), // -1.5 + fp8_t(static_cast(0xC0)), // -2 + fp8_t(static_cast(0xC4)), // -3 + fp8_t(static_cast(0xC8)), // -4 + fp8_t(static_cast(0xCC)) // -6 + }; +#else // CK_TILE_USE_FNUZ_FP8 + // FP8 E4M3 FNUZ + static constexpr fp8_t e2m1_to_fp8_table[16] = { + fp8_t(static_cast(0x00)), // 0 + fp8_t(static_cast(0x38)), // 0.5 + fp8_t(static_cast(0x40)), // 1 + fp8_t(static_cast(0x44)), // 1.5 + fp8_t(static_cast(0x48)), // 2 + fp8_t(static_cast(0x4C)), // 3 + fp8_t(static_cast(0x50)), // 4 + fp8_t(static_cast(0x54)), // 6 + fp8_t(static_cast(0x00)), // -0 + fp8_t(static_cast(0xB8)), // -0.5 + fp8_t(static_cast(0xC0)), // -1 + fp8_t(static_cast(0xC4)), // -1.5 + fp8_t(static_cast(0xC4)), // -2 + fp8_t(static_cast(0xCC)), // -3 + fp8_t(static_cast(0xD0)), // -4 + fp8_t(static_cast(0xD4)) // -6 + }; +#endif + #endif }; @@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; + // #endif +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + // NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8 + // would be better than the naive implementation below + // #if CK_TILE_FP4_CVT_DEVICE + // return impl::_from_f4(data, scale); + // #else + return fp8x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; + // #endif +} #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { @@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, + e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { @@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)}; } +CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const +{ + return type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale; +} +CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const +{ + return fp8x2_t{ + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)}; +} #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index d5df4d1917..9eb62a6ec4 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/random.hpp" #include @@ -23,6 +24,11 @@ struct pk_int4_t type data; CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} + + // NOTE: added for interface compatibility with pk_fp4_t + // Other data types could be added for greater similarity + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } }; // limits @@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const +{ + return pk_int4_t_to_fp32x2_t(*this); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 90ddc2a56e..def054f415 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/mixed_prec_compute_type.hpp b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp new file mode 100644 index 0000000000..021763c108 --- /dev/null +++ b/include/ck_tile/core/utility/mixed_prec_compute_type.hpp @@ -0,0 +1,54 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include + +namespace ck_tile { + +namespace detail { + +// Helper method to automatically determine compute type +// Selects the largest type of the two. If both of them are packed data types, defaults to fp8. +template +struct auto_compute_type +{ + using LargestInputType = largest_type_t; + + // Sanity check: there are no packed types larger than 1 byte yet, but if we add them + // this logic should change + static_assert(!is_packed_type_v || sizeof(LargestInputType) == sizeof(fp8_t)); + + using type = std::conditional_t, fp8_t, LargestInputType>; +}; + +// Helper method to determine compute type, defaulting an explicitly passed-in compute type +template +struct mixed_prec_compute_type +{ + using type = std::conditional_t, + typename auto_compute_type::type, + ComputeDataType>; +}; + +} // namespace detail + +template +using mixed_prec_compute_type_t = + typename detail::mixed_prec_compute_type::type; + +// Helper method to determine compute type, defaulting to input data type +// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed, +// ComputeDataType is used. +template +using mixed_prec_compute_type_from_input_t = std::conditional_t< + is_packed_type_v, + std::conditional_t, ComputeDataType, OtherDataType>, + ThisDataType>; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index f07e25e19c..c11d180839 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + #include #include #include @@ -187,4 +189,19 @@ template using tuple_element_or_default_t = typename tuple_element_or_default::type; +// Helper struct to determine if a type is packed (more than 1 element per byte) +template +struct is_packed_type +{ + static constexpr bool value = numeric_traits::PackedSize > 1; +}; + +template +static constexpr bool is_packed_type_v = is_packed_type::value; + +// Helper definition to take the largest sizes type +template +using largest_type_t = + std::conditional_t= sizeof(BDataType), ADataType, BDataType>; + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 9ad5af8264..7830150b63 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -137,47 +137,55 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const std::size_t M = a_m_k.get_length(0); - const std::size_t N = b_k_n.get_length(1); - const std::size_t K = a_m_k.get_length(1); + constexpr auto A_TENSOR_M_DIM = 0; + constexpr auto A_TENSOR_K_DIM = 1; + constexpr auto B_TENSOR_K_DIM = 0; + constexpr auto B_TENSOR_N_DIM = 1; + + const std::size_t M = a_m_k.get_length(A_TENSOR_M_DIM); + const std::size_t N = b_k_n.get_length(B_TENSOR_N_DIM); + const std::size_t K = a_m_k.get_length(A_TENSOR_K_DIM); + + // Pre-convert A/B tensors to AccData type + // This prevents doing slow reconversions for each row/column + HostTensor a_acc(a_m_k.mDesc); + HostTensor b_acc(b_k_n.mDesc); + + a_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const ADataType pk_val = a_element_op(a_m_k(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[A_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo; + } + else + { + self(index) = ck_tile::type_convert(a_element_op(a_m_k(index))); + } + }); + + b_acc.ForEach([&](auto& self, auto index) { + if constexpr(std::is_same_v || std::is_same_v) + { + const BDataType pk_val = b_element_op(b_k_n(index)); + const fp32x2_t fp32_val = pk_val.to_fp32x2(); + self(index) = (index[B_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + self(index) = fp8_to_float_raw(b_element_op(b_k_n(index))); + } + else + { + self(index) = ck_tile::type_convert(b_element_op(b_k_n(index))); + } + }); auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; constexpr std::size_t kGroupK = BQuantGroupSize::kK; - // ---- A loader: dequant A(m,k) into AccDataType ---- - auto load_a = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else - { - return ck_tile::type_convert(a_element_op(a_m_k(m, k))); - } - }; - - // ---- B loader: dequant B(k,n) into AccDataType ---- - auto load_b = [&](std::size_t k) -> AccDataType { - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - return (k & 1) ? fp32_val.hi : fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - return fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } - else - { - return ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - }; - // ---- a scale loader for a given K-group index ---- auto load_scale_a = [&](ck_tile::index_t k_group) -> float { const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; @@ -224,8 +232,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, // unscaled accumulation within this K-group for(std::size_t k = k_begin; k < k_end; ++k) { - const AccDataType v_a = load_a(k); - const AccDataType v_b = load_b(k); + const AccDataType v_a = a_acc(m, k); + const AccDataType v_b = b_acc(k, n); v_block_acc += v_a * v_b; } diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df..3f1a3b8f1c 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -4,11 +4,12 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { -template +template struct InterleavedPKTypeLoader { template @@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto in_dstr_tensors = load_tile(warp_window); - using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); + // NOTE: we rely on types packing neatly here + using RawSrcType = typename SrcDataType::type; + constexpr auto PackedSize = numeric_traits::PackedSize; + + using SrcVectorType = ext_vector_t; + using DstVectorType = ext_vector_t; static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + in_dstr_tensors.get_thread_buffer().template get_as()[i]); }); } }; @@ -37,10 +43,11 @@ template CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(is_packed_type_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); + InterleavedPKTypeLoader::load_interleaved_pk_type( + dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index ca9af0a7a8..3f58eceb33 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -397,6 +397,29 @@ struct PassThroughPack8 y.hi = i4_to_bf8x4(bit_cast(x) >> 8); #endif } + + CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const + { + pk_fp4_t f0 = pk_fp4_t{x[0]}; + pk_fp4_t f1 = pk_fp4_t{x[1]}; + pk_fp4_t f2 = pk_fp4_t{x[2]}; + pk_fp4_t f3 = pk_fp4_t{x[3]}; + + fp8x2_t x0 = f0.to_fp8x2(); + fp8x2_t x1 = f1.to_fp8x2(); + fp8x2_t x2 = f2.to_fp8x2(); + fp8x2_t x3 = f3.to_fp8x2(); + + y[0] = x0[0]; + y[1] = x0[1]; + y[2] = x1[0]; + y[3] = x1[1]; + y[4] = x2[0]; + y[5] = x2[1]; + y[6] = x3[0]; + y[7] = x3[1]; + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1784436f87..0044b412ec 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" @@ -255,17 +256,26 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + + // Determine compute types to use + // This logic defaults to A/B DataType, but if one of them is packed falls back to the other + // If both are packed, it falls back to the explicitly defined ComputeDataType in the + // problem It might be a good idea to use ComputeDataType anyway, but that would break how + // this behaviour used to work + using ATypeToUse = mixed_prec_compute_type_from_input_t; + using BTypeToUse = mixed_prec_compute_type_from_input_t; + constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; constexpr index_t KLaneBytes = KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg typename BFlatBlockTensor, typename AQBlockTensor, typename BQBlockTensor, - typename ABlockWindow> + typename ABlockWindow, + index_t UnaryOpSize = 8> CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, ABlockTensor& a_warp_tensor, BFlatBlockTensor& b_warp_tensor, @@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows(number{})(number{})); + + load_int4_tile( + a_warp_tensor(number{}), + a_warp_windows(number{})(number{})); } // barrier // Could be deleted diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 2d28b813bf..d79bd31489 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // 4. i4, bf8, (fp8/fp32) -> f32 static_assert( (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || - std::is_same_v) && + std::is_same_v || + std::is_same_v) && (std::is_same_v || std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v || @@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = std::conditional_t< - std::is_same_v && - std::is_same_v, - ADataType, - BDataType>; + // A/B DataType get converted from PkInt4/PkFp4 during loading + using OverrideADataType = ComputeDataType; + using OverrideBDataType = ComputeDataType; using Base = BlockGemmQuantBase; using WarpGemm = remove_cvref_t; @@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS + load_int4_tile( a_warp_tile_, a_block_window); - // If B datatype were pkint4 it would be converted prior to storing in LDS load_int4_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp index 095275e60b..b636bfa4b7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_policy.hpp @@ -10,9 +10,10 @@ namespace ck_tile { -struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmABQuantPipelineAgBgCrDefaultPolicy + : public UniversalGemmBasePolicy { - using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base = UniversalGemmBasePolicy; using Base::I0; using Base::I1; using Base::I2; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index 5902dd0c4f..cfd12313e8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using AQuantGroupSize = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; - // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!"); @@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; + // A/B DataType gets converted from PkInt4/PkFp4 during loading + using OverrideADataType = BlockGemm::OverrideADataType; + using OverrideBDataType = BlockGemm::OverrideBDataType; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -281,9 +282,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -303,9 +304,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = decltype(make_static_distributed_tensor(AQBlockTileDistr{})); using BQBlockTile = @@ -361,7 +362,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -373,7 +374,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -409,7 +410,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -420,7 +422,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); @@ -493,7 +495,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: ADataType gets converted during loading from PkInt4/PkFp4 + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -543,9 +546,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, m, @@ -593,9 +596,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: ADataType PkInt4/PkFp4 gets converted during loading + [](const OverrideADataType& a) { return a; }, b_dram_block_window_tmp, - // Note: BDataType PkInt4 gets converted during loading + // Note: BDataType PkInt4/PkFp4 gets converted during loading [](const OverrideBDataType& b) { return b; }, aq_dram_block_window_tmp, bq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 1edbe9ac16..9b02585e69 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -21,23 +21,27 @@ template -struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase +struct GemmQuantPipelineProblemBase + : public GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t> { - using Base = GemmPipelineProblemBase; + + using Base = GemmPipelineProblemBase< + ADataType_, + BDataType_, + CDataType_, + BlockGemmShape_, + Traits_, + mixed_prec_compute_type_t>; using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index ae2a601f8a..f136b86314 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -95,11 +95,6 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using BTypeToUse = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; using BDataType = typename Problem::BDataType; @@ -107,8 +102,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel KLane / numeric_traits::PackedSize * sizeof(BDataType); constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -239,36 +240,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + auto a_dram_tile_distribution = + PipelinePolicy::template MakeADramTileDistribution(); + auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_ping = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); + a_dram_tile_distribution); // ping-pong window for A LDS + auto a_warp_tile_distribution = + make_static_tile_distribution(typename WG::AWarpDstrEncoding{}); + auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); auto a_warp_window_pong_tmp = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + a_warp_tile_distribution); statically_indexed_array< statically_indexed_array, @@ -314,7 +321,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe b_flat_distribution); using BTypeToUse = - std::conditional_t, ADataType, BDataType>; + mixed_prec_compute_type_from_input_t; using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); // pingpong buffer for B @@ -354,7 +361,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -393,15 +400,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; + using ATypeToUse = + mixed_prec_compute_type_from_input_t; + using ATileType = + decltype(make_static_distributed_tensor(a_warp_tile_distribution)); + statically_indexed_array a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -434,7 +443,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -450,8 +459,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // Next K @@ -463,7 +472,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -495,8 +504,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; HotLoopScheduler(); @@ -513,7 +522,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_int4_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -535,8 +544,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + load_int4_tile( + a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); // GEMM loopK diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 9dd9670ff5..8e005d588e 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -76,6 +76,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base + test_gemm_quant_abquant_a4w4_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding + test_gemm_quant_abquant_a4w4_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle + test_gemm_quant_abquant_a4w4_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant test_gemm_quant_abquant_preshuffleQuant.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp new file mode 100644 index 0000000000..5e2403f7d1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_base.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp new file mode 100644 index 0000000000..1e496d5b64 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_padding.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false + // RCR layout with RowMajor AQ, ColumnMajor BQ + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK) +{ + this->run_test_with_validation(1024, 1024, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN) +{ + this->run_test_with_validation(1024, 832, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM) +{ + this->run_test_with_validation(832, 1024, 1024); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK) +{ + this->run_test_with_validation(832, 832, 832); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp new file mode 100644 index 0000000000..43051c8d08 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_a4w4_preshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using Half = ck_tile::half_t; +using PkFP4 = ck_tile::pk_fp4_t; +using ABQuantGrouped = + std::integral_constant; + +// 1d block sizes for AQuant +using GroupSize1D = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // RCR layout with RowMajor AQ, ColumnMajor BQ + // PreshuffleB = true && TransposeC = false + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 7be4131db4..5937b44229 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -209,7 +209,7 @@ template <> struct QuantTypeTraits { template - using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + using ComputeDataType = void; // Use automatically determined compute type static constexpr const char* name = "abquant"; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9683fa98aa..0033bb42a8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1174,8 +1174,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase>; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType,