[CK_TILE] Fix gemm_quant (#3186)

This commit is contained in:
linqunAMD
2025-11-12 00:23:57 +08:00
committed by GitHub
parent 88e3212fcc
commit 1b1c46e508
13 changed files with 135 additions and 49 deletions

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f)
/// Example entry point.
///
/// Forwards the command line to run_gemm_example, instantiated with the GEMM
/// config chosen at compile time by CK_TILE_USE_WMMA, and returns the logical
/// negation of its result (a truthy result therefore yields exit code 0).
int main(int argc, char* argv[])
{
#if !CK_TILE_USE_WMMA
    // Non-preshuffled GemmConfig, used for 2D block scale support.
    const auto example_result = run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
#else
    // WMMA builds use the config variant with WMMA-specific warp tiles.
    const auto example_result = run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
#endif
    return !example_result;
}

View File

@@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
/// Variant of GemmConfigBQuantPrefill<PrecType> whose M, N and K warp-tile
/// extents are all fixed to 16. The constants declared here hide any
/// same-named constants inherited from the base config.
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
{
    // Warp tile extent along M.
    static constexpr ck_tile::index_t M_Warp_Tile{16};
    // Warp tile extent along N.
    static constexpr ck_tile::index_t N_Warp_Tile{16};
    // Warp tile extent along K.
    static constexpr ck_tile::index_t K_Warp_Tile{16};
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,

View File

@@ -24,16 +24,43 @@ template <typename GemmConfig, typename T>
// Re-lays out a 2-D host tensor `t` for the B operand: the elements of `t` are
// copied verbatim into a higher-rank HostTensor whose extents are derived from
// GemmConfig's warp-tile sizes, and that view is then handed to
// ck_tile::reference_permute with a fixed axis order. The permuted tensor is
// returned.
//
// t.get_lengths()[0] is read as k_ and t.get_lengths()[1] as n_.
// Assumes n_ and k_ are multiples of N_Warp_Tile / K_Warp_Tile; the integer
// divisions below truncate otherwise, and t_view would then hold fewer elements
// than std::copy writes into it — TODO confirm callers guarantee this.
//
// NOTE(review): this listing declares n_/k_ twice and has a `return` ahead of
// the architecture-specific branches. It looks like the pre-change body and the
// post-change body of the commit are shown together — confirm against the
// actual header before relying on it.
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
// Only rank-2 tensors are accepted.
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
// Split factor applied to K_Warp_Tile: 2 when N_Warp_Tile is 32, otherwise 4.
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
// 5-D view: {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
// Element order is unchanged by the copy; only the permute below reorders data.
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
// Path taken when ck_tile::is_gfx12_supported() reports true.
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
// K_Warp_Tile / 16; evaluates to 0 when K_Warp_Tile is smaller than 16.
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
// 6-D view: the K warp tile is factored as kABK0PerLane x divisor x kABK1PerLane.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
// Runtime split factor: stays 1 when ck_tile::is_gfx11_supported() is true;
// otherwise it is chosen from N_Warp_Tile exactly as in the constexpr form above.
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
// The remaining path is asserted not to be wave32.
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
// Same 5-D view and axis order as the single-path version at the top.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
@@ -55,21 +82,46 @@ template <typename GemmConfig, typename T>
// Variant of the B-operand shuffle that additionally splits the N extent by
// N_Tile, N_Warp and NRepeat. The elements of the 2-D tensor `t` are copied
// verbatim into a higher-rank HostTensor and that view is passed to
// ck_tile::reference_permute with a fixed axis order; the permuted tensor is
// returned.
//
// t.get_lengths()[0] is read as k_ and t.get_lengths()[1] as n_.
// Assumes n_ is a multiple of N_Tile and k_ a multiple of K_Warp_Tile; the
// integer divisions below truncate otherwise — TODO confirm callers guarantee this.
//
// NOTE(review): this listing has a `return` ahead of the architecture-specific
// branches. It looks like the pre-change body and the post-change body of the
// commit are shown together — confirm against the actual header.
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
// Only rank-2 tensors are accepted.
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
// Split factor applied to K_Warp_Tile: 2 when N_Warp_Tile is 32, otherwise 4.
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
// N_Tile / N_Warp_Tile / N_Warp, evaluated with integer division.
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
// 7-D view: {n_ / N_Tile, N_Warp, N_Warp_Tile, NRepeat, k_ / K_Warp_Tile, divisor,
// K_Warp_Tile / divisor}.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
// Element order is unchanged by the copy; only the permute below reorders data.
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
// Path taken when ck_tile::is_gfx12_supported() reports true.
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
// K_Warp_Tile / 16; evaluates to 0 when K_Warp_Tile is smaller than 16.
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
// 8-D view: the K warp tile is factored as kABK0PerLane x divisor x kABK1PerLane.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
// Runtime split factor: stays 1 when ck_tile::is_gfx11_supported() is true;
// otherwise it is chosen from N_Warp_Tile exactly as in the constexpr form above.
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
// The remaining path is asserted not to be wave32.
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
// Same 7-D view and axis order as the single-path version at the top.
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
} // namespace ck_tile

View File

@@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kCMLane = Impl::kCMLane;
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

View File

@@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, float>)
{

View File

@@ -25,13 +25,11 @@ struct BlockGemmAQuantBase
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<AQDataType, float>)
{
@@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
constexpr uint32_t kTileRowsOfCPerThread = 4;
constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
@@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
// desired row coefficient
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
// Multiply by 4 because output is stored in tiles of 4

View File

@@ -25,13 +25,11 @@ struct BlockGemmBQuantBase
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, float>)
{

View File

@@ -240,7 +240,10 @@ struct QuantGemmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
// Host-side block dimension for the kernel launch: kBlockSize / 2 when
// is_wave32() reports true, kBlockSize otherwise. Evaluated at run time.
CK_TILE_HOST static auto BlockSize()
{
    if(is_wave32())
    {
        return dim3(kBlockSize / 2);
    }
    return dim3(kBlockSize);
}
CK_TILE_HOST static constexpr QuantGemmKernelArgs
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)

View File

@@ -41,7 +41,8 @@ template <bool kPadM_,
typename BQLayout_ = BLayout_,
bool TransposeC_ = false,
bool DoubleSmemBuffer_ = false,
bool UsePersistentKernel_ = false>
bool UsePersistentKernel_ = false,
int VectorSize_ = 16>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
@@ -50,7 +51,7 @@ struct TileGemmQuantTraits
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_;

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Typed Test Suite for GEMM Quantization
add_gtest_executable(test_tile_gemm_quant_typed
test_gemm_quant_typed.cpp

View File

@@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
// The WP pipeline requires the per-thread tile size to be aligned to Problem::VectorLoadSize:
// static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
// VectorLoadSize == 0). gfx9 cards meet this requirement, but it fails on gfx12, so we only
// need to check the limitation on RDNA cards, i.e. assume the wave size is 32.
constexpr ck_tile::index_t WaveSize = 32;
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
constexpr bool SupportVectorSize16 =
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
@@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
ALayout,
BLayout,
GemmConfig::TransposeC,
DoubleSmemBuffer>;
DoubleSmemBuffer,
false,
VectorSize>;
// Let the derived class create the appropriate pipeline and epilogue
static_cast<Derived*>(this)

View File

@@ -7,6 +7,16 @@
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
/// Compile-time K warp-tile extent for the test GEMM configs.
///
/// When CK_TILE_USE_WMMA is set the result is always 16; otherwise it is 64
/// if `is_8bit` is true and 32 if it is false.
template <bool is_8bit>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
    constexpr ck_tile::index_t k_warp_tile = 16;
#else
    constexpr ck_tile::index_t k_warp_tile = is_8bit ? 64 : 32;
#endif
    return k_warp_tile;
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -40,7 +50,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
@@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
};
struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
@@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
};
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
@@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// AQuant uses grouped quantization for A matrix
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
@@ -373,7 +383,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// BQuant uses block/grouped quantization for B matrix
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
@@ -629,7 +639,7 @@ class TestCkTileGemmRowColQuant
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// RowColQuant uses per-row and per-column scales
const ck_tile::index_t stride_row_scales = 1;
@@ -846,7 +856,7 @@ class TestCkTileGemmTensorQuant
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
const ck_tile::index_t stride_C = N;
// TensorQuant uses single scalar scale for each tensor
const ck_tile::index_t stride_scale_a = 1;