mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[CK_TILE] Row/Col quant gemm (#2729)
* Add cshuffle epilogue test * add the poc implementation to the epilogue and tests * refactor cshuffle epilogue * WIP: adding tensor/tile usage to scale_tile * fix usage of tile_elementwise_inout * add gemm_quant_kernel for generalizing gemm quant kernel * Add problem specific to different quants, add QuantType to Traits * Add quant_type to quant_kernel template parameters * Create aq/bq_block_windows and views depending on QuantType * Use tile windows as inputs in cshuffle epilogue * Fix some issues in epilogue * initial new example code for new general gemm quant kernel test * Fix issues in kernel * Add verification check for rowcol Quantmode * use AccDataType instead of AQ in pipeline * fix aquant preshuffle * fix formatting * some cleanup * remove gemm_aquant_basic.cpp * remove gemm_aquant_kernel.hpp * fix tests for the renamed quant kernel * fix formatting * clean example files * fix some merge conflicts * fix preshufflequant rename issue * fix some templates after merging with develop * fix test preshuffle parameter * fix formatting * Unify bquant kernel to the common quant kernel * remove bquant kernel also from common header * fix formatting * clean up commented code * fix formatting config hpp * fix merge mistake * Non-const for movable windows * fix formatting * Fix grammar in README Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com> * Remove #include<bit> and clean up example * fix strides * Add some descriptions for move_windows --------- Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com> Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
@@ -240,4 +240,4 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -55,13 +55,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
false, // preshuffle
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -114,8 +115,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -185,7 +188,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::AQuantGemmHostArgs args;
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
@@ -194,7 +197,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK = AQK;
|
||||
args.QK_A = AQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
Reference in New Issue
Block a user