Weight Preshuffle Block Scale gemm support (#2877)

* initial commit

* remove extra files

* fixing errors

* updated ReadMe file for mapping of diff quants with diff configs

* addressing review comments

* addressing review comments

* Resolved merge conflicts

* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled

The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.

* initial commit

* debugging

* working fp8 for init constant

* fp8 working with all inits

* updated block level code with comments

* changing the loop iter

* debugging

* debugging

* debugging

* code fix

* code clean up

* clang formatted

* Add comment

* code cleanup

* clang formatted

* merge conflicts fixes

* applying the latest int4 changes to the piepline

* fixing test code for updated traits

* Adding gtest

* review comments addressed

* addressing review comments

* remove c++20 code

* added flush cache changes

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>

[ROCm/composable_kernel commit: 81458a6681]
This commit is contained in:
Khushbu Agarwal
2025-09-29 12:46:37 -07:00
committed by GitHub
parent 47b8632296
commit 7c20b1f690
17 changed files with 1129 additions and 53 deletions

View File

@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
namespace ck_tile {
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// B is block window on block distributed tensor.
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_>
struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
{
using Problem = remove_cvref_t<Problem_>;
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto idxM = I0;
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
static constexpr auto warp_size = get_warp_size();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t QScalesPerWarpGemmRow =
(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
template <typename T>
CK_TILE_DEVICE static float cvt_scale_to_fp32(T& scale)
{
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_f = ck_tile::bit_cast<float>(scale);
}
else
{
static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
}
return scale_reg_f;
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C += A * B
template <typename CBlockTensor,
typename ABlockTensor,
typename BFlatBlockTensor,
typename BQBlockTensor,
typename ABlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockTensor& a_warp_tensor,
BFlatBlockTensor& b_warp_tensor,
BQBlockTensor& bq_block_tensor,
ABlockWindow& a_warp_windows) const
{
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
CWarpTensor c_warp_tensor;
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
// warp GEMM
if constexpr(kIterInQScale == 0)
c_warp_tensor = WG{}(a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
else
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(merge_sequences(
sequence<number<0>{}, number<0>{}>{}, c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr index_t reg_offset = kQScale;
// nIter * KPerBlockBQ + kQScale; //((kIter * WG::kK) / kQuantGroupSize);
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
}
};
} // namespace ck_tile

View File

@@ -344,11 +344,11 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
if constexpr(Traits::PreshuffleQuant)
{
static_assert(false,
"It is not supported yet to enable both Preshuffle and "
"TransposeC.");
if constexpr(Traits::TransposeC) // transposed C
{
static_assert(false,
"It is not supported yet to enable both Preshuffle "
"and TransposeC.");
// TODO:
// A new tile distribution is needed for the Preshuffle and
// Transpose combination. For instance, with mnk at 16x16x32, lanes