feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)

* docs(README): update readme with new build instructions

* feat(grouped_gemm): add support back for non persistent kernel

* refactor(grouped_gemm): simplify tensor creation

* refactor(grouped_gemm): Persistance is now GemmConfig value for easier management

* chore(grouped_gemm): add print statements to ease debugging

* WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration

* fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity

* WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines

* chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle

* style: clang format work so far

* BUG!(grouped_gemm_kernel.hpp): figured out a potential bug in for numerical errors in preshuffle pipeline

* fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors

* refactor(gemm_presuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm

* chore(grouped_gemm): add/remove debug comments and debug print statements

* feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes

* chore(gemm_profile): add new argument combinations

* fix: branch cleanup, formatting, refactoring

* fix: branch cleanup, formatting, refactoring

* chore(changelog):  update changelog to reflect new featuer

* address review comments & nit
This commit is contained in:
Aviral Goel
2025-09-07 17:18:35 -04:00
committed by GitHub
parent 5224d2ead3
commit e279e9420e
13 changed files with 808 additions and 178 deletions

View File

@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);