Fix broken ping pong pipeline functionality in the develop branch.

This commit is contained in:
Sudhir Kylasa
2025-09-16 18:31:35 +00:00
parent 7d7ded62d3
commit e1c87eac08
3 changed files with 4 additions and 2 deletions

View File

@@ -252,7 +252,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>

View File

@@ -102,6 +102,7 @@ struct UniversalInvoker
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::K_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,

View File

@@ -40,6 +40,7 @@ template <typename ADataType_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t KWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
@@ -59,7 +60,7 @@ struct CShuffleEpilogueProblem
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;