supporting prefill shapes for preshuffle block scale gemm (#2975)

* debugging

* debugging for prefill shapes

* comment unused code

* fix for prefill shapes

* clearing up the code

* add int4 to universal gemm example

* clang formatted

* adding test for prefill shapes in block scale gemm

* lil improv on the block pipeline

* Address Review Comment

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-10-10 15:36:24 -07:00
committed by GitHub
parent 9d060d3e3c
commit 3c39d279ab
10 changed files with 137 additions and 89 deletions

View File

@@ -75,6 +75,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "int4")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");

View File

@@ -194,10 +194,7 @@ struct WeightPreshuffleInvoker
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
throw std::runtime_error("split-k is not supported yet!");
}
};

View File

@@ -300,16 +300,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if(init_method == 0)
{
if constexpr(preshuffle)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
else if(init_method == 1)
{
@@ -353,6 +345,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
}
}();
// shuffled buffer B for device implementation
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
}
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
else

View File

@@ -4,8 +4,18 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
- Row and Column-wise scaled: scaling implemented in Epilogue
- Tensor-wise scaled: scaling implemented in Epilogue
- Row and Column-wise scaled: All of the rowwise elements in A Matrix and columwise elements in B Matrix will share the same quantization element and the elementwisde operation will complete in epilogue.
- Tensor-wise scaled: Share the same scalar scale across the whole tensor of A or B
---
## Features
- **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
- **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading
- **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
- **Validation**: CPU/GPU validation and error tolerance options.
## build
```

View File

@@ -47,6 +47,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
QuantMode,
ALayout, // for AQLayout
BLayout, // for BQLayout
false,
GemmConfig::DoubleSmemBuffer>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
@@ -450,4 +451,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv); }

View File

@@ -166,6 +166,26 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
@@ -261,7 +281,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -127,62 +127,72 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
BQBlockTensor& bq_block_tensor,
ABlockWindow& a_warp_windows) const
{
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstr = typename WG::CWarpDstr;
using AccTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
CWarpTensor c_warp_tensor;
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
statically_indexed_array<statically_indexed_array<AccTensor, NIterPerWarp>, MIterPerWarp>
c_acc;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
// warp GEMM
if constexpr(kIterInQScale == 0)
c_warp_tensor = WG{}(a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
else
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
auto zero_accumulators = [&] {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
}); // make sure WG::CWarpTensor exposes a clear/zero
});
});
};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(merge_sequences(
sequence<number<0>{}, number<0>{}>{}, c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
constexpr index_t reg_offset = kQScale;
// nIter * KPerBlockBQ + kQScale; //((kIter * WG::kK) / kQuantGroupSize);
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
});
});
});
}

View File

@@ -1111,7 +1111,6 @@ struct QuantGemmKernel
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{

View File

@@ -57,26 +57,10 @@ struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
static constexpr bool TransposeC = true;
};
struct GemmConfigPreshuffleB
struct GemmConfigPreshuffleBDecode : public GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 16;
@@ -92,6 +76,25 @@ struct GemmConfigPreshuffleB
static constexpr ck_tile::index_t K_Warp_Tile = 64;
};
struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
{
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{

View File

@@ -62,10 +62,15 @@ using BQuantTypes = ::testing::Types<
// clang-format off
using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>
>;
// clang-format off