diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index bc7bcce34c..ed7cadc41c 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -67,7 +67,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -82,7 +81,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmShape, GemmUniversalTraits, QuantGroupSize, - transpose_c>, + GemmConfig::TransposeC>, ck_tile::GemmBQuantPipelineProblem>; @@ -161,7 +160,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, int main(int argc, char* argv[]) { - int result1 = run_grouped_gemm_example(argc, argv); /* || - run_grouped_gemm_example(argc, argv);*/ + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 554584c2bd..eb8ccd86b9 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -102,6 +102,24 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfig_Aquant : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool TransposeC = true; +}; + template struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { @@ -118,10 +136,42 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_from_preshuffled_warp_tile(); + static constexpr bool TransposeC = false; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfig_Aquant; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index f2d7f628a5..3628247f44 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -533,12 +533,13 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - using Types = GemmTypeConfig; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using GemmConfig = GemmQuantConfig::template GemmConfig; + using Types = GemmTypeConfig; // Specific type aliases for easy access using ADataType = typename Types::ADataType; using BDataType = typename Types::BDataType; @@ -567,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } -template