[CK_TILE] Row/Col quant gemm (#2729)

* Add cshuffle epilogue test

* add the poc implementation to the epilogue and tests

* refactor cshuffle epilogue

* WIP: adding tensor/tile usage to scale_tile

* fix usage of tile_elementwise_inout

* add gemm_quant_kernel for generalizing gemm quant kernel

* Add problem specific to different quants, add QuantType to Traits

* Add quant_type to quant_kernel template parameters

* Create aq/bq_block_windows and views depending on QuantType

* Use tile windows as inputs in cshuffle epilogue

* Fix some issues in epilogue

* initial new example code for new general gemm quant kernel test

* Fix issues in kernel

* Add verification check for rowcol Quantmode

* use AccDataType instead of AQ in pipeline

* fix aquant preshuffle

* fix formatting

* some cleanup

* remove gemm_aquant_basic.cpp

* remove gemm_aquant_kernel.hpp

* fix tests for the renamed quant kernel

* fix formatting

* clean example files

* fix some merge conflicts

* fix preshufflequant rename issue

* fix some templates after merging with develop

* fix test preshuffle parameter

* fix formatting

* Unify bquant kernel to the common quant kernel

* remove bquant kernel also from common header

* fix formatting

* clean up commented code

* fix formatting config hpp

* fix merge mistake

* Non-const for movable windows

* fix formatting

* Fix grammar in README

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Remove #include<bit> and clean up example

* fix strides

* Add some descriptions for move_windows

---------

Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
Sami Remes
2025-09-05 02:17:12 +03:00
committed by GitHub
parent 7330ec37ee
commit c6010f2953
23 changed files with 1837 additions and 1331 deletions

View File

@@ -240,4 +240,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -26,7 +26,7 @@ template <typename GemmConfig,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -55,13 +55,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
false, // preshuffle
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -114,8 +115,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -185,7 +188,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
@@ -194,7 +197,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;