[CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle (#2897)

* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

When TransposeC and QuantPreshuffle are both true, Aquant generates
the correct result.

* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

- Add unit tests

* Fix bug in is_quantpreshuffle_enabled

* clang format

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-10-02 12:13:51 -06:00
committed by GitHub
parent a4ab33f539
commit 6fc28ab493
7 changed files with 109 additions and 15 deletions

View File

@@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
if constexpr(Traits::TransposeC) // transposed C
{
static_assert(false,
"It is not supported yet to enable both Preshuffle "
"and TransposeC.");
// TODO:
// A new tile distribution is needed for the Preshuffle and
// Transpose combination. For instance, with mnk at 16x16x32, lanes
// 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ.
constexpr auto tbuf_offset = number<
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) *
Traits::AQPerBlock +
kQScale;
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f);
});
}
else
{

View File

@@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled
};
template <typename T>
struct is_quantpreshuffle_enabled<T, decltype(T::PreshuffleQuant)>
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
{
static constexpr bool value = T::PreshuffleQuant;
};

View File

@@ -39,6 +39,7 @@ template <bool kPadM_,
QuantType QuantType_,
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_,
bool TransposeC_ = false,
bool DoubleSmemBuffer_ = false,
bool UsePersistentKernel_ = false>
struct TileGemmQuantTraits
@@ -62,7 +63,7 @@ struct TileGemmQuantTraits
using AsLayout = ALayout_;
using BsLayout = BLayout_;
static constexpr bool TransposeC = false;
static constexpr bool TransposeC = TransposeC_;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;