mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Enable padding blockscale for abquant (#3453)
* Enable padding blockscale for abquant * run clang-format * Reduce unnecessary testing * remove cout
This commit is contained in:
@@ -466,41 +466,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kN != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"N must be aligned with QuantGroupSize for ABQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t AQK, BQK, BQN = 0;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
|
||||
@@ -142,99 +142,98 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0, v_block_acc = 0;
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, pk_int4_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> ||
|
||||
std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
constexpr std::size_t kGroupK = BQuantGroupSize::kK;
|
||||
|
||||
// ---- A loader: dequant A(m,k) into AccDataType ----
|
||||
auto load_a = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
};
|
||||
|
||||
// ---- B loader: dequant B(k,n) into AccDataType ----
|
||||
auto load_b = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_block_acc += v_a * v_b;
|
||||
};
|
||||
|
||||
// Apply group dequant scale
|
||||
if((k + 1) % BQuantGroupSize::kK == 0)
|
||||
// ---- a scale loader for a given K-group index ----
|
||||
auto load_scale_a = [&](ck_tile::index_t k_group) -> float {
|
||||
const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM;
|
||||
const ck_tile::index_t inner_dim = k_group;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
float a_scale = 0.f;
|
||||
float b_scale = 0.f;
|
||||
// A scale
|
||||
index_t outer_dim = m / AQuantGroupSize::kM;
|
||||
index_t inner_dim = k / AQuantGroupSize::kK;
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
a_scale = a_q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
a_scale = fp8_to_float_raw(a_q(outer_dim, inner_dim));
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
a_scale = bf8_to_float_raw(a_q(outer_dim, inner_dim));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected Q datatype.");
|
||||
}
|
||||
// B scale
|
||||
outer_dim = k / BQuantGroupSize::kK;
|
||||
inner_dim = n / BQuantGroupSize::kN;
|
||||
if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
b_scale = b_q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
b_scale = fp8_to_float_raw(b_q(outer_dim, inner_dim));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
b_scale = bf8_to_float_raw(b_q(outer_dim, inner_dim));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected Q datatype.");
|
||||
}
|
||||
v_block_acc = v_block_acc * a_scale * b_scale;
|
||||
v_acc += v_block_acc;
|
||||
v_block_acc = 0;
|
||||
return a_q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
return fp8_to_float_raw(a_q(outer_dim, inner_dim));
|
||||
}
|
||||
else // QDataType == bf8_t by static_assert above
|
||||
{
|
||||
return bf8_to_float_raw(a_q(outer_dim, inner_dim));
|
||||
}
|
||||
};
|
||||
// ---- b scale loader for a given K-group index ----
|
||||
auto load_scale_b = [&](ck_tile::index_t k_group) -> float {
|
||||
const ck_tile::index_t outer_dim = k_group;
|
||||
const ck_tile::index_t inner_dim = n / BQuantGroupSize::kN;
|
||||
|
||||
if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
return b_q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
return fp8_to_float_raw(b_q(outer_dim, inner_dim));
|
||||
}
|
||||
else // QDataType == bf8_t by static_assert above
|
||||
{
|
||||
return bf8_to_float_raw(b_q(outer_dim, inner_dim));
|
||||
}
|
||||
};
|
||||
// ---- Loop over K by groups (full and tail) ----
|
||||
for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
|
||||
{
|
||||
const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
|
||||
|
||||
AccDataType v_block_acc = 0;
|
||||
|
||||
// unscaled accumulation within this K-group
|
||||
for(std::size_t k = k_begin; k < k_end; ++k)
|
||||
{
|
||||
const AccDataType v_a = load_a(k);
|
||||
const AccDataType v_b = load_b(k);
|
||||
v_block_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
|
||||
const float scale_a = load_scale_a(k_group);
|
||||
const float scale_b = load_scale_b(k_group);
|
||||
|
||||
v_acc += v_block_acc * scale_a * scale_b;
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
|
||||
@@ -412,32 +412,6 @@ struct QuantGemmKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
|
||||
@@ -25,14 +25,20 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr
|
||||
test_gemm_quant_aquant_base_ccr.cpp
|
||||
)
|
||||
# ABQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant
|
||||
test_gemm_quant_abquant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_base
|
||||
test_gemm_quant_abquant_base.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_padding
|
||||
test_gemm_quant_abquant_padding.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
test_gemm_quant_aquant_prefill.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for ABQuant padding padding tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPaddingTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigPadding, GroupSize, GroupSize, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant Padding
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 832, 832);
|
||||
}
|
||||
@@ -80,6 +80,10 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
|
||||
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool kPadM = GemmConfig::kPadM;
|
||||
static constexpr bool kPadN = GemmConfig::kPadN;
|
||||
static constexpr bool kPadK = GemmConfig::kPadK;
|
||||
|
||||
public:
|
||||
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
|
||||
|
||||
@@ -88,9 +92,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
// Common test execution logic
|
||||
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
// WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize.
|
||||
// static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
// VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only
|
||||
|
||||
@@ -83,6 +83,12 @@ struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPadding : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
|
||||
Reference in New Issue
Block a user