mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle (#2897)
* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

  When TransposeC and QuantPreshuffle are both true, Aquant now generates the correct result.

* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

  - Add unit tests

* Fix bug in is_quantpreshuffle_enabled

* clang format

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
if constexpr(Traits::TransposeC) // transposed C
|
||||
{
|
||||
static_assert(false,
|
||||
"It is not supported yet to enable both Preshuffle "
|
||||
"and TransposeC.");
|
||||
// TODO:
|
||||
// A new tile distribution is needed for the Preshuffle and
|
||||
// Transpose combination. For instance, with mnk at 16x16x32, lanes
|
||||
// 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ.
|
||||
constexpr auto tbuf_offset = number<
|
||||
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) *
|
||||
Traits::AQPerBlock +
|
||||
kQScale;
|
||||
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_quantpreshuffle_enabled<T, decltype(T::PreshuffleQuant)>
|
||||
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
|
||||
{
|
||||
static constexpr bool value = T::PreshuffleQuant;
|
||||
};
|
||||
|
||||
@@ -39,6 +39,7 @@ template <bool kPadM_,
|
||||
QuantType QuantType_,
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
struct TileGemmQuantTraits
|
||||
@@ -62,7 +63,7 @@ struct TileGemmQuantTraits
|
||||
using AsLayout = ALayout_;
|
||||
using BsLayout = BLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
|
||||
Reference in New Issue
Block a user