mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] Fix gemm_quant (#3186)
This commit is contained in:
@@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
|
||||
@@ -25,13 +25,11 @@ struct BlockGemmAQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
@@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
|
||||
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
@@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// desired row coefficient
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
|
||||
;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
|
||||
@@ -25,13 +25,11 @@ struct BlockGemmBQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
|
||||
@@ -240,7 +240,10 @@ struct QuantGemmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr QuantGemmKernelArgs
|
||||
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
|
||||
|
||||
@@ -41,7 +41,8 @@ template <bool kPadM_,
|
||||
typename BQLayout_ = BLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
bool UsePersistentKernel_ = false,
|
||||
int VectorSize_ = 16>
|
||||
struct TileGemmQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -50,7 +51,7 @@ struct TileGemmQuantTraits
|
||||
|
||||
static constexpr QuantType kQuantType = QuantType_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
|
||||
Reference in New Issue
Block a user