mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)
* docs(README): update readme with new build instructions * feat(grouped_gemm): add support back for non persistent kernel * refactor(grouped_gemm): simplify tensor creation * refactor(grouped_gemm): Persistance is now GemmConfig value for easier management * chore(grouped_gemm): add print statements to ease debugging * WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration * fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity * WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines * chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle * style: clang format work so far * BUG!(grouped_gemm_kernel.hpp): figured out a potential bug in for numerical errors in preshuffle pipeline * fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors * refactor(gemm_presuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm * chore(grouped_gemm): add/remove debug comments and debug print statements * feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes * chore(gemm_profile): add new argument combinations * fix: branch cleanup, formatting, refactoring * fix: branch cleanup, formatting, refactoring * chore(changelog): update changelog to reflect new featuer * address review comments & nit
This commit is contained in:
@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
|
||||
Reference in New Issue
Block a user