// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

// Host-side declarations for the ck_tile grouped GEMM: pipeline identifiers, warp-tile
// helpers, element-type and tile configurations, pipeline type traits, command-line
// parsing, the B-tensor shuffle helper, and the launch entry points.
//
// NOTE(review): in this copy of the header every non-empty angle-bracketed list is missing:
// the `template <...>` parameter heads, template arguments at use sites
// (e.g. `std::is_same_v || std::is_same_v`, `get_k_warp_tile()`, `ck_tile::HostTensor t_view`,
// `std::vector& gemm_descs`, `std::pair create_args`), the arguments of the
// `template <> struct X` specializations, and the targets of the first two `#include` lines.
// They must be restored from upstream before this header can compile; the comments below
// describe only what the remaining tokens show.

#pragma once

#include
#include

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/utility/json_dump.hpp"

// Pipeline identifiers. The `Pipeline` member of GemmConfigBase and of each configuration
// below holds one of these values.
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4

// Returns the K extent of a warp tile for element type PrecType and M warp-tile size
// M_Warp_Tile.
// With CK_GFX950_SUPPORT defined, the result also depends on `is_8bit_float` (whether
// PrecType matches either of the two types tested with std::is_same_v; by its name, the
// 8-bit float types):
//   M_Warp_Tile == 32 : 64 if is_8bit_float, else 16
//   any other value   : 128 if is_8bit_float, else 32
// Without CK_GFX950_SUPPORT: 16 when M_Warp_Tile == 32, otherwise 32.
template constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
    constexpr bool is_8bit_float = std::is_same_v || std::is_same_v;
    if constexpr(M_Warp_Tile == 32)
        return is_8bit_float ? 64 : 16;
    else
        return is_8bit_float ? 128 : 32;
#else
    if constexpr(M_Warp_Tile == 32)
        return 16;
    else
        return 32;
#endif
}

// Returns the K extent of a warp tile for the preshuffle configurations
// (GemmConfigPreshuffleDecode / GemmConfigPreshufflePrefill use it). Chosen from the element
// size of PrecType (2 bytes vs. anything else) and M_Warp_Tile:
//                                   sizeof == 2   otherwise
//   CK_GFX950_SUPPORT, M == 32    :     16           64
//   CK_GFX950_SUPPORT, other M    :     32          128
//   no CK_GFX950_SUPPORT, M == 32 :     16           32
//   no CK_GFX950_SUPPORT, other M :     32           64
template constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
    if constexpr(M_Warp_Tile == 32)
        return sizeof(PrecType) == 2 ? 16 : 64;
    else
        return sizeof(PrecType) == 2 ? 32 : 128;
#else
    if constexpr(M_Warp_Tile == 32)
        return sizeof(PrecType) == 2 ? 16 : 32;
    else
        return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}

// Maps a precision selector to the A, B, C and accumulator element types. Only the primary
// template is declared; the two specializations below supply the definitions.
template struct GemmTypeConfig;

// half_t A, B and C with float accumulation.
template <> struct GemmTypeConfig
{
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using CDataType = ck_tile::half_t;
    using AccDataType = float;
};

// fp8_t A and B, float accumulation, half_t C.
template <> struct GemmTypeConfig
{
    using ADataType = ck_tile::fp8_t;
    using BDataType = ck_tile::fp8_t;
    using AccDataType = float;
    using CDataType = ck_tile::half_t;
};

// Default values shared by every GEMM configuration. Derived configurations re-declare only
// the members they change; a re-declared static member hides the base one of the same name.
struct GemmConfigBase
{
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool PermuteA = false;
    static constexpr bool PermuteB = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    // NOTE(review): "Paritioner" is misspelled, but it is part of the member name that users
    // of these configs reference; renaming needs all of them updated together.
    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    static constexpr ck_tile::index_t TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool Preshuffle = false;
    static constexpr bool Persistent = true;
    static constexpr bool DoubleSmemBuffer = false;
};

// Compute-V3 configuration: 128 x 128 block tile whose K extent is 128 bytes' worth of
// PrecType elements, a 2 x 2 x 1 warp layout and 32 x 32 warp tiles (K from
// get_k_warp_tile). DoubleSmemBuffer, Pipeline and kBlockPerCu restate the GemmConfigBase
// defaults.
template struct GemmConfigComputeV3_2 : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

    static constexpr int kBlockPerCu = 1;
};

// Compute-V4 configuration: same block/warp geometry as GemmConfigComputeV3_2 (32 x 32 warp
// tiles), with DoubleSmemBuffer enabled, the COMPUTE_V4 pipeline and kBlockPerCu = 2.
template struct GemmConfigComputeV4 : public GemmConfigBase
{
    // Compute V4 only supports the Intrawave scheduler (inherited from GemmConfigBase).
    // It uses a ping-pong reader at the LDS level.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;

    static constexpr int kBlockPerCu = 2;
};

// Compute-V4 configuration identical to GemmConfigComputeV4 except for 16 x 16 warp tiles.
template struct GemmConfigComputeV4_V2 : public GemmConfigBase
{
    // Compute V4 only supports the Intrawave scheduler (inherited from GemmConfigBase).
    // It uses a ping-pong reader at the LDS level.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile();

    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;

    static constexpr int kBlockPerCu = 2;
};

// Preshuffle (PRESHUFFLE_V2) configuration with a small 16 x 64 block tile, K extent of
// 256 bytes' worth of PrecType elements, a 1 x 4 x 1 warp layout, 16 x 16 warp tiles (K from
// get_k_warp_tile_flatmm), K padding enabled and the Default scheduler.
template struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm();

    static constexpr bool kPadK = true;
    static constexpr int kBlockPerCu = 1;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
    static constexpr bool Preshuffle = true;
    static constexpr bool DoubleSmemBuffer = true;
};

// Preshuffle (PRESHUFFLE_V2) configuration with a 128 x 128 block tile (K extent of 128
// bytes' worth of PrecType elements), a 1 x 4 x 1 warp layout, 16 x 16 warp tiles (K from
// get_k_warp_tile_flatmm), kBlockPerCu = 2, K padding enabled and the Default scheduler.
template struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm();

    static constexpr int kBlockPerCu = 2;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
    static constexpr bool Preshuffle = true;
    static constexpr bool DoubleSmemBuffer = true;
    static constexpr bool kPadK = true;
};

// Compute-V4 configuration with a fixed 16 x 16 x 16 warp tile (K_Warp_Tile does not depend
// on PrecType here); otherwise matches GemmConfigComputeV4_V2.
template struct GemmConfigComputeV4_Wmma : public GemmConfigBase
{
    // Compute V4 only supports the Intrawave scheduler (inherited from GemmConfigBase).
    // It uses a ping-pong reader at the LDS level.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;

    static constexpr int kBlockPerCu = 2;
};

// Preshuffle (PRESHUFFLE_V2) configuration with a fixed 16 x 16 x 16 warp tile. The M tile
// is 32 bytes' worth of PrecType elements (16 for 2-byte, 32 for 1-byte types); the rest
// matches GemmConfigPreshuffleDecode.
template struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType);
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool kPadK = true;
    static constexpr int kBlockPerCu = 1;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
    static constexpr bool Preshuffle = true;
    static constexpr bool DoubleSmemBuffer = true;
};

// Maps a pipeline selector to its pipeline class templates. Each specialization exposes two
// alias templates: GemmPipeline and UniversalGemmPipeline (the matching Base* pipeline).
// NOTE(review): the specialization arguments are missing in this copy; judging by the
// pipeline names they are presumably CK_TILE_PIPELINE_MEMORY, CK_TILE_PIPELINE_COMPUTE_V3,
// CK_TILE_PIPELINE_COMPUTE_V4 and CK_TILE_PIPELINE_PRESHUFFLE_V2, in that order — confirm.
template struct PipelineTypeTraits;

// Memory pipeline (AgBgCrMem).
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem;
};

// Compute V3 pipeline (AgBgCrCompV3).
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3;
};

// Compute V4 pipeline (AgBgCrCompV4).
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4;
    template using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4;
};

// Weight-preshuffle V2 pipeline.
template <> struct PipelineTypeTraits
{
    template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2;
    template using UniversalGemmPipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2;
};

// Host-side descriptor type for one GEMM group (element type of the gemm_descs vectors).
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;

// Registers the command-line options with their defaults and help strings, then parses
// argv. Returns the parser's success flag together with the parser itself.
// NOTE(review): "b_layout" defaults to "C" but its help text says "Row by default"; the help
// string looks wrong (should presumably say Column).
// NOTE(review): "prec" advertises fp16/bf16/fp8/bf8, but this header only shows two
// GemmTypeConfig specializations (half_t and fp8_t element types) — confirm bf16/bf8 are
// handled elsewhere.
std::pair create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("Ms", "", "M dimensions - empty by default.")
        .insert("Ns", "", "N dimensions - empty by default.")
        .insert("Ks", "", "K dimensions - empty by default.")
        .insert("stride_As", "", "Tensor A strides - it is empty by default.")
        .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
        .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
        .insert("a_layout", "R", "A tensor data layout - Row by default.")
        .insert("b_layout", "C", "B tensor data layout - Row by default.")
        .insert("c_layout", "R", "C tensor data layout - Row by default.")
        .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "10", "number of iterations before benchmark the kernel.")
        .insert("repeat", "100", "number of iterations to benchmark the kernel.")
        .insert("group_count", "8", "group count.")
        .insert("kbatch", "1", "kbatch for SplitK")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "grouped_gemm.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_pair(result, arg_parser);
}

// Size in bytes of a buffer holding one ck_tile::GemmTransKernelArg per entry of gemm_descs.
inline std::size_t get_workspace_size(const std::vector& gemm_descs)
{
    return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}

// Reorders the elements of the 2-D host tensor t into warp-tile-blocked order and returns
// the permuted copy (presumably the pre-shuffled B operand for configs with
// Preshuffle == true — confirm against callers).
// The lengths are read as k_ = lengths[0] and n_ = lengths[1]. t's elements are copied flat
// into a higher-rank view whose shape is built from GemmConfig::N_Warp_Tile and
// GemmConfig::K_Warp_Tile, and that view is permuted with ck_tile::reference_permute.
// NOTE(review): the view dimensions use integer division (n_ / N_Warp_Tile,
// k_ / K_Warp_Tile, and on gfx12 K_Warp_Tile / 2 / 8), so the view only has as many
// elements as t when those divisions are exact. Otherwise std::copy writes past the end of
// t_view; nothing here checks it.
template auto shuffle_b(const ck_tile::HostTensor& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    if(ck_tile::is_gfx12_supported())
    {
        // Each K warp tile is split as (kABK0PerLane, divisor = 2, kABK1PerLane = 8).
        constexpr int divisor = 2;
        constexpr int kABK1PerLane = 8;
        constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
        ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile,
                                    GemmConfig::N_Warp_Tile,
                                    k_ / GemmConfig::K_Warp_Tile,
                                    kABK0PerLane,
                                    divisor,
                                    kABK1PerLane});
        std::copy(t.begin(), t.end(), t_view.begin());
        // Result order: (N blocks, K blocks, divisor, N_Warp_Tile, kABK0PerLane, kABK1PerLane).
        return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
    }
    else
    {
        // K warp tile is split as (divisor, K_Warp_Tile / divisor): 1 on gfx11, otherwise
        // 2 when N_Warp_Tile == 32 and 4 for any other N_Warp_Tile.
        int divisor = 1;
        if(ck_tile::is_gfx11_supported())
        {
            // NOTE(review): redundant — divisor is already initialized to 1.
            divisor = 1;
        }
        else
        {
            assert(is_wave32() == false);
            divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
        }
        ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile,
                                    GemmConfig::N_Warp_Tile,
                                    k_ / GemmConfig::K_Warp_Tile,
                                    divisor,
                                    GemmConfig::K_Warp_Tile / divisor});
        std::copy(t.begin(), t.end(), t_view.begin());
        // Result order: (N blocks, K blocks, divisor, N_Warp_Tile, K_Warp_Tile / divisor).
        return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
    }
}

// Runs the grouped GEMM for the groups described by gemm_descs, using the kernel-argument
// buffer at kargs_ptr and the stream settings in s. Declared here; defined elsewhere.
// NOTE(review): presumably kargs_ptr addresses get_workspace_size(gemm_descs) bytes and the
// float result is the timing measured under s — confirm against the definition.
template float grouped_gemm(const std::vector& gemm_descs,
                            const ck_tile::stream_config& s,
                            void* kargs_ptr);

// Variant that receives only the number of groups; presumably the per-group arguments are
// read through kargs_ptr — confirm. splitk defaults to false. Declared here; defined
// elsewhere.
template float grouped_gemm_tileloop(const ck_tile::stream_config& s,
                                     const ck_tile::index_t num_groups,
                                     void* kargs_ptr,
                                     bool splitk = false);