[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and corresponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* addressing review comments

* fixing CI issue

* addressing review comments

* formatting

* formatting

* fixing aquant operator overloading

* formatting

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-11-24 07:48:42 -08:00
committed by GitHub
parent e857e26bf6
commit 8111572785
31 changed files with 855 additions and 247 deletions

View File

@@ -54,6 +54,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
@@ -172,16 +174,47 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
else
{
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
});
});
});

View File

@@ -274,7 +274,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
return Base::cvt_scale_to_fp32(gathered_scale_reg);
}
@@ -368,7 +367,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
return exchange_quant_value_across_lanes(scale_reg, pull_from_lane);
}
else
@@ -511,6 +509,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});

View File

@@ -100,6 +100,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
@@ -173,6 +175,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
static constexpr auto a_warp_y_lengths =
@@ -321,31 +325,65 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
}
});
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
}
});
});
});