[CK_TILE] Fix gemm_quant (#3186)

[ROCm/composable_kernel commit: 1b1c46e508]
This commit is contained in:
linqunAMD
2025-11-12 00:23:57 +08:00
committed by GitHub
parent c1b5372db3
commit 13cf0bd17f
13 changed files with 135 additions and 49 deletions

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Typed Test Suite for GEMM Quantization
add_gtest_executable(test_tile_gemm_quant_typed
test_gemm_quant_typed.cpp

View File

@@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
// The WP pipeline requires the per-thread tile size to be aligned to
// Problem::VectorLoadSize, i.e. it enforces:
//   static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
//                 VectorLoadSize == 0)
// gfx9 cards satisfy this requirement, but it fails on gfx12. Therefore the limitation
// only needs to be checked for RDNA cards, i.e. a wave size of 32 is assumed here.
constexpr ck_tile::index_t WaveSize = 32;
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
constexpr bool SupportVectorSize16 =
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
@@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
ALayout,
BLayout,
GemmConfig::TransposeC,
DoubleSmemBuffer>;
DoubleSmemBuffer,
false,
VectorSize>;
// Let the derived class create the appropriate pipeline and epilogue
static_cast<Derived*>(this)

View File

@@ -7,6 +7,16 @@
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
// Selects the K extent of the warp tile at compile time.
//
// When CK_TILE_USE_WMMA evaluates to a non-zero value the result is always 16,
// independent of is_8bit. Otherwise the result is 64 when is_8bit is true and
// 32 when it is false.
template <bool is_8bit>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
    // WMMA build: a single K size is used for every element width.
    constexpr ck_tile::index_t k_warp_tile = 16;
#else
    // Non-WMMA build: the K size depends on the is_8bit flag.
    constexpr ck_tile::index_t k_warp_tile = is_8bit ? 64 : 32;
#endif
    return k_warp_tile;
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -40,7 +50,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
@@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
};
struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
@@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
};
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
@@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// AQuant uses grouped quantization for A matrix
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
@@ -373,7 +383,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// BQuant uses block/grouped quantization for B matrix
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
@@ -629,7 +639,7 @@ class TestCkTileGemmRowColQuant
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// RowColQuant uses per-row and per-column scales
const ck_tile::index_t stride_row_scales = 1;
@@ -846,7 +856,7 @@ class TestCkTileGemmTensorQuant
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// TensorQuant uses single scalar scale for each tensor
const ck_tile::index_t stride_scale_a = 1;