diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 2ea8530cb2..8141d99286 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -8,11 +8,10 @@ #include #include -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template + uint32_t QuantGroupSize> float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) { constexpr bool kPadM = false; @@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s static_assert(std::is_same_v); - constexpr ck_tile::index_t M_Tile = 16; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 256; + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; - constexpr ck_tile::index_t K_Warp = 1; + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - constexpr ck_tile::index_t M_Warp_Tile = 16; - constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmAQuantTraits; + using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase{}); + decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else if(data_type == "bf8") { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = + decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } @@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[]) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); @@ -209,29 +213,15 @@ int run_gemm_example(int argc, char* argv[]) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } - else if(data_type == "i4f32fp8") - { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, TypeConfig, 128>( - a_layout, b_layout, argc, argv); - } - else if(data_type == "i4f32bf8") - { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, TypeConfig, 128>( - a_layout, b_layout, argc, argv); - } else { throw std::runtime_error("Unsupported data type for this operation !!!"); } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp index 4adc3df94b..0690c4884f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp @@ -8,11 +8,10 @@ #include #include -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template + uint32_t QuantGroupSize> float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) { constexpr bool kPadM = false; @@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s static_assert(std::is_same_v); - constexpr ck_tile::index_t M_Tile = 16; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 256; + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; - constexpr ck_tile::index_t K_Warp = 1; + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - constexpr ck_tile::index_t M_Warp_Tile = 16; - constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmAQuantTraits; + using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase{}); + decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else if(data_type == "bf8") { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = + decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } @@ -200,7 +204,7 @@ int run_gemm_example(int argc, char* argv[]) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); @@ -209,29 +213,18 @@ int run_gemm_example(int argc, char* argv[]) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } - else if(data_type == "i4f32fp8") - { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, TypeConfig, 128>( - a_layout, b_layout, argc, argv); - } - else if(data_type == "i4f32bf8") - { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, TypeConfig, 128>( - a_layout, b_layout, argc, argv); - } else { throw std::runtime_error("Unsupported data type for this operation !!!"); } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + return !run_gemm_example(argc, argv); +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 0d0da93133..83a53e3c13 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -11,11 +11,9 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_group_quant.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE 5 +#define CK_TILE_PIPELINE_PREFILL 1 +#define CK_TILE_PIPELINE_DECODE 2 +#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3 template constexpr ck_tile::index_t get_k_warp_tile() @@ -87,196 +85,32 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool Preshuffle = false; + static constexpr bool PreshuffleQuant = false; + static constexpr bool DoubleSmemBuffer = false; }; template -struct GemmConfigMemoryInterwave : public GemmConfigBase +struct GemmConfigDecode : public GemmConfigBase { - // Memory friendly for Interwave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; -}; - -template -struct GemmConfigMemoryIntrawave : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; -}; - -template -struct GemmConfigComputeV3 : public GemmConfigBase -{ - // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 32; - static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; -}; - -template -struct GemmConfigComputeV3_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; -}; - -template -struct GemmConfigComputeV3_2 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - - static constexpr int kBlockPerCu = 2; -}; - -template -struct GemmConfigComputeV4 : public GemmConfigBase -{ - // Compute V4 only support Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; -}; - -template -struct GemmConfigComputeV4_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; -}; - -template -struct GemmConfigComputeV5 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 2; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; -}; - -template -struct GemmConfigPreshufle_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); - - static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE; }; template -struct GemmConfigPreshufle_2 : public GemmConfigBase +struct GemmConfigPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -288,18 +122,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL; }; template -struct GemmConfigPreshufle_AQ : public GemmConfigBase +struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -314,9 +145,9 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_from_preshuffled_warp_tile(); - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT; + static constexpr bool PreshuffleQuant = true; }; template -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::half_t; - using QDataType = float; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf16_t; - using QDataType = float; - using BDataType = ck_tile::bf16_t; - using AccDataType = float; - using CDataType = ck_tile::bf16_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::half_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - template struct DataTypeTraits; @@ -559,55 +220,6 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template -struct PipelineTypeTraits; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - template - using UniversalGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc index 6b5e01ca4c..8b045a2cf4 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc @@ -31,7 +31,8 @@ auto shuffle_aq(const ck_tile::HostTensor& t, int block_aq_k) return ck_tile::reference_permute(t_view, {1, 0, 2}); } -template + uint32_t QuantGroupSize> float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& aq_m_aqk_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -73,7 +73,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_C = stride_C; args.stride_AQ = stride_AQ; - float ave_time = gemm_calc_aquant( + QuantGroupSize>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -206,7 +206,7 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - if constexpr(GemmConfig::Preshuffle) + if constexpr(GemmConfig::PreshuffleQuant) { ck_tile::HostTensor aq_shuffle_host = shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize); @@ -222,7 +222,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - aq_m_aqk_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - AQK, - stride_A, - stride_AQ, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + QuantGroupSize>(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index c6b8882946..d6921208c7 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase static constexpr index_t KPack = WarpGemm::kKPerThread; static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; - static constexpr bool Preshuffle = Problem::Traits::Preshuffle; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; }; public: @@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase } }); - if constexpr(Traits::Preshuffle) + if constexpr(Traits::PreshuffleQuant) { // A view is created on top of the preshuffled AQ, where each row of the // view is composed of a row from a warp tile within an AQ block tile. @@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0] // from thread 1, ..., and AQ_tile[3, 0] from thread 3. - auto pull_from_lane = - ((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN * - kTileRowsOfCPerThread + - c_row) * - Traits::QScalesPerBlockRow + - kQScale; + decltype(threadIdx.x) pull_from_lane = 0; + if constexpr(WarpGemm::kM == 16) + { + pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * + kTileRowsOfCPerThread + + c_row) * + Traits::QScalesPerBlockRow + + kQScale; + } + else if constexpr(WarpGemm::kM == 32) + { + pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * + kTileRowsOfCPerThread + + ((c_row >> 2) << 3) + (c_row & 0b11)) * + Traits::QScalesPerBlockRow + + kQScale; + } + else + { + static_assert(false, "WarpGemm::kM is not 16 nor 32."); + } auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; // cross lane ops diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp index 6973c80d57..49fbbfbc71 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs template struct AQuantGemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using AQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = GemmPipeline::BlockSize; - static constexpr bool Preshuffle = GemmPipeline::Preshuffle; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant; using ADataType = remove_cvref_t; using AQDataType = remove_cvref_t; @@ -422,9 +422,9 @@ struct AQuantGemmKernel ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); const auto aq_merge_pad1_desc = transform_tensor_descriptor( aq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)), + make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), make_pass_through_transform(pad_wave_size)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); return make_tensor_view(aq_ptr, aq_merge_pad1_desc); @@ -432,7 +432,7 @@ struct AQuantGemmKernel const auto& aq_tensor_view = [&]() { static_assert(std::is_same_v); - if constexpr(Preshuffle) + if constexpr(PreshuffleQuant) { return make_preshuffled_aq_tensor_view(); } @@ -599,10 +599,8 @@ struct AQuantGemmKernel } template - CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, - const AQuantGemmKernelArgs& kargs, - const index_t i_m, - const index_t i_n) + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { const auto& a_pad_view = views.at(I0); const auto& aq_pad_view = views.at(I1); @@ -628,24 +626,27 @@ struct AQuantGemmKernel const auto& aq_block_window = [&]() { static_assert(std::is_same_v); - if constexpr(Preshuffle) + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto block_k = TilePartitioner::KPerBlock; + constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0); + constexpr auto aqk_per_block = + TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize; + if constexpr(PreshuffleQuant) { - constexpr auto tile_window_width = get_warp_size(); - constexpr auto tile_window_height = - TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0); - auto block_m_idx = i_m / TilePartitioner::MPerBlock; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_m / warp_m; + auto block_m_idx = i_m / block_m; return make_tile_window( aq_pad_view, make_tuple(number{}, number{}), - {block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2), - 0}); + {block_m_idx * tile_window_height, 0}); } else { return make_tile_window( aq_pad_view, - make_tuple(number{}, - number{}), + make_tuple(number{}, number{}), {i_m, 0}); } }(); @@ -706,8 +707,7 @@ struct AQuantGemmKernel a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -718,7 +718,7 @@ struct AQuantGemmKernel const auto& b_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 1fb92ad14d..c1fdeefc0c 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using AQLayout = remove_cvref_t; using BlockGemmShape = typename Problem::BlockGemmShape; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; - constexpr index_t VecLoadSize = GetVectorSizeAQ(); - constexpr bool Preshuffle = Problem::Traits::Preshuffle; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; + constexpr index_t VecLoadSize = GetVectorSizeAQ(); + constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; static_assert(std::is_same_v); - if constexpr(Preshuffle) + if constexpr(PreshuffleQuant) { using TileEncodingPattern = TileDistributionEncodingPatternAQ; + PreshuffleQuant>; return TileEncodingPattern::Make2DStaticTileDistribution(); } @@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockAQ, KPerBlockAQ, VecLoadSize, - Preshuffle>; + PreshuffleQuant>; return TileEncodingPattern::Make2DStaticTileDistribution(); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 64b2402aa5..037cef0553 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -7,6 +7,7 @@ #include #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" @@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr bool Preshuffle = Problem::Traits::Preshuffle; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; @@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t m, index_t num_loop, void* p_smem) const { @@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); // only row_major for AQ - constexpr AQDramTileWindowStep aq_dram_tile_window_step = - Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0) - : make_array(0, KPerBlockAQ); + const AQDramTileWindowStep aq_dram_tile_window_step = + PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) / + BlockGemm::WarpGemm::kM, + 0) + : make_array(0, KPerBlockAQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t m, index_t num_loop, void* p_smem) const { @@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, + m, num_loop, p_smem); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp index 051543b8b6..99c8762366 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp @@ -52,7 +52,7 @@ template + bool PreshuffleQuant> struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() { - if constexpr(Preshuffle) + if constexpr(PreshuffleQuant) { // # of elements per thread - constexpr index_t X2 = KPerBlockAQ; - constexpr index_t X1 = warp_size / X2; + static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0); + constexpr index_t X1 = warp_size; constexpr index_t X0 = XPerTile / warp_size; constexpr index_t Y1 = MWarps; constexpr index_t Y0 = YPerTile / Y1; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 2>>, - tuple, sequence<1, 2>>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, sequence<1, 2>, sequence<0, 0>>{}); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp index 41f8f1deef..fe96c28f33 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp @@ -10,7 +10,7 @@ namespace ck_tile { template constexpr ck_tile::index_t get_k_warp_tile() @@ -34,21 +32,6 @@ constexpr ck_tile::index_t get_k_warp_tile() return 32; #endif } -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(__gfx950__) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} template auto calculate_rtol_atol(const ck_tile::index_t K, @@ -93,195 +76,32 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool Preshuffle = false; + static constexpr bool PreshuffleQuant = false; + static constexpr bool DoubleSmemBuffer = true; }; template -struct GemmConfigMemoryInterwave : public GemmConfigBase +struct GemmConfigDecode : public GemmConfigBase { - // Memory friendly for Interwave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; -}; - -template -struct GemmConfigMemoryIntrawave : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; -}; - -template -struct GemmConfigComputeV3 : public GemmConfigBase -{ - // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 32; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; -}; - -template -struct GemmConfigComputeV3_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; -}; - -template -struct GemmConfigComputeV3_2 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - - static constexpr int kBlockPerCu = 2; -}; - -template -struct GemmConfigComputeV4 : public GemmConfigBase -{ - // Compute V4 only support Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; -}; - -template -struct GemmConfigComputeV4_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; -}; - -template -struct GemmConfigComputeV5 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 2; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; -}; - -template -struct GemmConfigPreshufle_1 : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - - static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE; }; template -struct GemmConfigPreshufle_2 : public GemmConfigBase +struct GemmConfigPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -293,71 +113,32 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL; }; -template -struct GemmTypeConfig; - -template <> -struct GemmTypeConfig +template +struct GemmConfigPreshuffleQuant : public GemmConfigBase { - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; - // ToDo: Add more bias config to support different categories of GEMM. -}; + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::bf16_t; - using BDataType = ck_tile::bf16_t; - using AccDataType = float; - using CDataType = ck_tile::bf16_t; -}; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + get_k_from_preshuffled_warp_tile(); -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmTypeConfig -{ - using ADataType = ck_tile::int8_t; - using BDataType = ck_tile::int8_t; - using AccDataType = int32_t; - using CDataType = int32_t; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT; + static constexpr bool PreshuffleQuant = true; }; template -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::half_t; - using QDataType = float; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf16_t; - using QDataType = float; - using BDataType = ck_tile::bf16_t; - using AccDataType = float; - using CDataType = ck_tile::bf16_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::half_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = float; - using BDataType = ck_tile::fp8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::pk_int4_t; - using QDataType = float; - using BDataType = ck_tile::bf8_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = ck_tile::fp8_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = ck_tile::bf8_t; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::fp8_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = float; -}; - -template <> -struct GemmQuantTypeConfig -{ - using ADataType = ck_tile::bf8_t; - using QDataType = float; - using BDataType = ck_tile::pk_int4_t; - using AccDataType = float; - using CDataType = float; -}; - template struct DataTypeTraits; @@ -600,55 +211,6 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template -struct PipelineTypeTraits; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; - template - using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; -}; - -template <> -struct PipelineTypeTraits -{ - template - using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - template - using UniversalGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index e8ff45fc5e..3439309857 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -15,7 +15,8 @@ #include "ck_tile/host.hpp" #include "test_gemm_aquant_utils.hpp" -template + uint32_t QuantGroupSize> float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) { constexpr bool kPadM = false; @@ -36,17 +36,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s static_assert(std::is_same_v); - constexpr ck_tile::index_t M_Tile = 16; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 256; + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; - constexpr ck_tile::index_t K_Warp = 1; + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - constexpr ck_tile::index_t M_Warp_Tile = 16; - constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -55,8 +55,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmAQuantTraits; + using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase>{}; } -template + uint32_t QuantGroupSize> float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& aq_m_aqk_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -194,7 +199,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_C = stride_C; args.stride_AQ = stride_AQ; - float ave_time = gemm_calc_aquant( + QuantGroupSize>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -227,7 +232,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, return ave_time; } -template +template bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -412,7 +419,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg { if(a_layout == "R" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( argc, argv, Row{}, Row{}, Col{}, Row{}); } else @@ -428,6 +435,7 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg return true; } +template