mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[CK_TILE] Grouped GEMM tile loop (#2146)
* Add trait to use a persistent kernel and split the entrypoints in grouped gemm * Some helper functions for persistent kernel case * Get max occupancy grid using device properties * Implement tile loop in main entry point to grouped gemm * Enable GridSize() on device * Handle offset tile index using real current block index * Add persistent kernel choice to grouped gemm example * Use a for-loop for iterating over the group * Reduce VGPR spills by early-exit * Enable persistent kernel choice in grouped_gemm example * Add persistent kernel option to grouped_gemm test * Fix formatting with remod.py * Remove GridUpdateBlocks as blocks are now iteratively computed * Add comment about VGPR spilling * Fix formatting * Use CK_TILE_HOST instead of __host__ * Enable all Row/Col combinations in grouped gemm unit test * Add some KBatch=2 cases to grouped gemm tests * Fix SplitK for grouped gemm * Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm * Add type traits * Split examples to regular and tileloop * Formatting * Use hipExtStreamGetCUMask to get current active CUs for the given stream * Align test and example kernel config, and disable validation for splitk repeats * Remove debug options from CMakeLists.txt * Separate the code paths for persistent/non-persistent in test * Fix formatting * Address review comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -20,18 +20,19 @@ namespace ck_tile {
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
@@ -104,6 +105,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
|
||||
0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
@@ -38,7 +38,8 @@ template <bool kPadM_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -53,6 +54,27 @@ struct TileGemmUniversalTraits
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user