mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
[CK_TILE] Fix gemm_quant (#3186)
[ROCm/composable_kernel commit: 1b1c46e508]
This commit is contained in:
@@ -5,7 +5,7 @@ endif()
|
||||
|
||||
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Typed Test Suite for GEMM Quantization
|
||||
add_gtest_executable(test_tile_gemm_quant_typed
|
||||
test_gemm_quant_typed.cpp
|
||||
|
||||
@@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
// WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize.
|
||||
// static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
// VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only
|
||||
// need to check the limitation on RDNA cards, i.e. assume wave size is 32.
|
||||
constexpr ck_tile::index_t WaveSize = 32;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
|
||||
constexpr bool SupportVectorSize16 =
|
||||
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
|
||||
constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16;
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
@@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
ALayout,
|
||||
BLayout,
|
||||
GemmConfig::TransposeC,
|
||||
DoubleSmemBuffer>;
|
||||
DoubleSmemBuffer,
|
||||
false,
|
||||
VectorSize>;
|
||||
|
||||
// Let the derived class create the appropriate pipeline and epilogue
|
||||
static_cast<Derived*>(this)
|
||||
|
||||
@@ -7,6 +7,16 @@
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
template <bool is_8bit>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
return is_8bit ? 64 : 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -40,7 +50,7 @@ struct GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
@@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
|
||||
@@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
|
||||
@@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
@@ -373,7 +383,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
|
||||
@@ -629,7 +639,7 @@ class TestCkTileGemmRowColQuant
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// RowColQuant uses per-row and per-column scales
|
||||
const ck_tile::index_t stride_row_scales = 1;
|
||||
@@ -846,7 +856,7 @@ class TestCkTileGemmTensorQuant
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// TensorQuant uses single scalar scale for each tensor
|
||||
const ck_tile::index_t stride_scale_a = 1;
|
||||
|
||||
Reference in New Issue
Block a user