[CK_TILE] Fix gemm_quant (#3186)

This commit is contained in:
linqunAMD
2025-11-12 00:23:57 +08:00
committed by GitHub
parent 88e3212fcc
commit 1b1c46e508
13 changed files with 135 additions and 49 deletions

View File

@@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, float>)
{

View File

@@ -25,13 +25,11 @@ struct BlockGemmAQuantBase
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<AQDataType, float>)
{
@@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
constexpr uint32_t kTileRowsOfCPerThread = 4;
constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
@@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
// desired row coefficient
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
// Multiply by 4 because output is stored in tiles of 4

View File

@@ -25,13 +25,11 @@ struct BlockGemmBQuantBase
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
}
else if constexpr(std::is_same_v<BQDataType, float>)
{