diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp old mode 100644 new mode 100755 index 897952f03c..a821af0649 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,91 +16,50 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -template +template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, bool splitk) { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = true; -#endif - constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; // We create the GEMM pipeline without specifying hotloop or tailnumber. @@ -112,7 +71,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmUniversalTraits, scheduler>; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; @@ -145,7 +105,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -173,4 +133,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, #include "run_grouped_gemm_example.inc" constexpr bool Persistent = true; -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + return !run_grouped_gemm_example(argc, argv); +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 89d91fbef6..e992cb3118 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -15,24 +15,26 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4 +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; #else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; #endif +} template struct GemmTypeConfig; @@ -46,13 +48,109 @@ struct GemmTypeConfig using AccDataType = float; }; -using Types = GemmTypeConfig; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 1; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; @@ -69,6 +167,7 @@ auto create_args(int argc, char* argv[]) .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") @@ -98,7 +197,14 @@ float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr); -template +template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index fa7f1a31c1..425299203f 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -10,6 +10,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -30,7 +31,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template ( - stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr, splitk); } std::string op_name{"Grouped Gemm"}; @@ -127,7 +135,15 @@ float invoke_gemm(int n_warmup, return ave_time; } -template +template int run_grouped_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -243,7 +259,8 @@ int run_grouped_gemm_example_with_layouts(int argc, {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } - invoke_gemm, AccDataType, @@ -271,7 +288,9 @@ int run_grouped_gemm_example_with_layouts(int argc, a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref, "Error: Incorrect results!", @@ -288,7 +307,61 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template typename GemmConfig> int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -297,30 +370,22 @@ int run_grouped_gemm_example(int argc, char* argv[]) return -1; } - const std::string a_layout = arg_parser.get_str("a_layout"); - const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string data_type = arg_parser.get_str("prec"); - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if(a_layout == "R" && b_layout == "C") + if(data_type == "fp16") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + return run_gemm_example_prec_type, ck_tile::half_t>( + a_layout, b_layout, argc, argv); } - else if(a_layout == "R" && b_layout == "R") + else if(data_type == "fp8") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + return run_gemm_example_prec_type, ck_tile::fp8_t>( + a_layout, b_layout, argc, argv); } else { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + throw std::runtime_error("Unsupported data type configuration."); } }