mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Fix gemm_quant (#3186)
This commit is contained in:
@@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
|
||||
@@ -25,13 +25,11 @@ struct BlockGemmAQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
@@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
|
||||
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
@@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// desired row coefficient
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
|
||||
;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
|
||||
@@ -25,13 +25,11 @@ struct BlockGemmBQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user