diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 2ddb96f620..47a22cdcba 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -466,41 +466,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if(K % AQuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for AQuantGrouped mode"); - } - } - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if(K % BQuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for BQuantGrouped mode"); - } - } - if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if(K % AQuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for ABQuantGrouped mode"); - } - if(K % BQuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for ABQuantGrouped mode"); - } - if(K % BQuantGroupSize::kN != 0) - { - throw std::runtime_error( - "N must be aligned with QuantGroupSize for ABQuantGrouped mode"); - } - } - ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 05c98e7bb5..9ad5af8264 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -142,99 +142,98 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, const std::size_t K = a_m_k.get_length(1); auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0, v_block_acc = 0; + AccDataType v_acc = 
0; - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v || - std::is_same_v); - for(std::size_t k = 0; k < K; ++k) - { - AccDataType v_a; - AccDataType v_b; + constexpr std::size_t kGroupK = BQuantGroupSize::kK; + + // ---- A loader: dequant A(m,k) into AccDataType ---- + auto load_a = [&](std::size_t k) -> AccDataType { if constexpr(std::is_same_v) { const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_a = fp32_val.hi; - else - v_a = fp32_val.lo; + return (k & 1) ? fp32_val.hi : fp32_val.lo; } else { - v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + return ck_tile::type_convert(a_element_op(a_m_k(m, k))); } + }; + // ---- B loader: dequant B(k,n) into AccDataType ---- + auto load_b = [&](std::size_t k) -> AccDataType { if constexpr(std::is_same_v) { const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_b = fp32_val.hi; - else - v_b = fp32_val.lo; + return (k & 1) ? 
fp32_val.hi : fp32_val.lo; } else if constexpr(std::is_same_v) { - v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); + return fp8_to_float_raw(b_element_op(b_k_n(k, n))); } else { - v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + return ck_tile::type_convert(b_element_op(b_k_n(k, n))); } - v_block_acc += v_a * v_b; + }; - // Apply group dequant scale - if((k + 1) % BQuantGroupSize::kK == 0) + // ---- a scale loader for a given K-group index ---- + auto load_scale_a = [&](ck_tile::index_t k_group) -> float { + const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM; + const ck_tile::index_t inner_dim = k_group; + + if constexpr(std::is_same_v) { - float a_scale = 0.f; - float b_scale = 0.f; - // A scale - index_t outer_dim = m / AQuantGroupSize::kM; - index_t inner_dim = k / AQuantGroupSize::kK; - if constexpr(std::is_same_v) - { - a_scale = a_q(outer_dim, inner_dim); - } - else if constexpr(std::is_same_v) - { - a_scale = fp8_to_float_raw(a_q(outer_dim, inner_dim)); - } - else if constexpr(std::is_same_v) - { - a_scale = bf8_to_float_raw(a_q(outer_dim, inner_dim)); - } - else - { - static_assert(false, "Unexpected Q datatype."); - } - // B scale - outer_dim = k / BQuantGroupSize::kK; - inner_dim = n / BQuantGroupSize::kN; - if constexpr(std::is_same_v) - { - b_scale = b_q(outer_dim, inner_dim); - } - else if constexpr(std::is_same_v) - { - b_scale = fp8_to_float_raw(b_q(outer_dim, inner_dim)); - } - else if constexpr(std::is_same_v) - { - b_scale = bf8_to_float_raw(b_q(outer_dim, inner_dim)); - } - else - { - static_assert(false, "Unexpected Q datatype."); - } - v_block_acc = v_block_acc * a_scale * b_scale; - v_acc += v_block_acc; - v_block_acc = 0; + return a_q(outer_dim, inner_dim); } + else if constexpr(std::is_same_v) + { + return fp8_to_float_raw(a_q(outer_dim, inner_dim)); + } + else // QDataType == bf8_t by static_assert above + { + return bf8_to_float_raw(a_q(outer_dim, inner_dim)); + } + }; + // ---- b scale loader for a given K-group index 
---- + auto load_scale_b = [&](ck_tile::index_t k_group) -> float { + const ck_tile::index_t outer_dim = k_group; + const ck_tile::index_t inner_dim = n / BQuantGroupSize::kN; + + if constexpr(std::is_same_v) + { + return b_q(outer_dim, inner_dim); + } + else if constexpr(std::is_same_v) + { + return fp8_to_float_raw(b_q(outer_dim, inner_dim)); + } + else // QDataType == bf8_t by static_assert above + { + return bf8_to_float_raw(b_q(outer_dim, inner_dim)); + } + }; + // ---- Loop over K by groups (full and tail) ---- + for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK) + { + const std::size_t k_end = std::min(k_begin + kGroupK, K); + + AccDataType v_block_acc = 0; + + // unscaled accumulation within this K-group + for(std::size_t k = k_begin; k < k_end; ++k) + { + const AccDataType v_a = load_a(k); + const AccDataType v_b = load_b(k); + v_block_acc += v_a * v_b; + } + + const ck_tile::index_t k_group = static_cast(k_begin / kGroupK); + const float scale_a = load_scale_a(k_group); + const float scale_b = load_scale_b(k_group); + + v_acc += v_block_acc * scale_a * scale_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 8e37cae359..ba67a9ee4d 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -412,32 +412,6 @@ struct QuantGemmKernel return false; } - if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) - { - if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - - if constexpr(kQuantType == QuantType::BQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) - { - if(kargs.QK_B % 
GemmPipeline::GetVectorSizeBQ() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - if constexpr(std::is_same_v) { if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 46b02b4b0b..f89aea1c17 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -25,14 +25,20 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr test_gemm_quant_aquant_base_ccr.cpp ) - # ABQuant tests - add_gtest_executable(test_tile_gemm_quant_abquant - test_gemm_quant_abquant.cpp - ) - target_compile_options(test_tile_gemm_quant_abquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant tests + add_gtest_executable(test_tile_gemm_quant_abquant_base + test_gemm_quant_abquant_base.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_padding + test_gemm_quant_abquant_padding.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp similarity index 100% rename from test/ck_tile/gemm_block_scale/test_gemm_quant_abquant.cpp rename to test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp new file mode 100644 index 
0000000000..5247a4405d --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant padding tests +// Tuple format: +// clang-format off +using ABQuantPaddingTypes = ::testing::Types< + std::tuple +>; +// clang-format on + +// Test suite for ABQuant Padding +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes); + +// ABQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 3ecbbf046b..8c9955da74 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -80,6 +80,10 @@ class TestCkTileGemmQuantBase : public ::testing::Test static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + + static constexpr bool kPadM = GemmConfig::kPadM; + static constexpr bool kPadN = GemmConfig::kPadN; + static constexpr bool kPadK = GemmConfig::kPadK; + public: void SetUp() override { static_cast(this)->SetUpQuantTypeSpecific(); } @@ -88,9 +92,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test // Common test execution logic void
invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) { - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; // WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize. // static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % // VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 24a05d6267..7d82958acf 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -83,6 +83,12 @@ struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase static constexpr bool TransposeC = true; }; +struct GemmConfigPadding : public GemmConfigBase +{ + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; +}; + struct GemmConfigPreshuffleBDecode : public GemmConfigBase { static constexpr bool PreshuffleB = true;