add the Preshuffle3

This commit is contained in:
ThomasNing
2025-07-18 07:11:14 +00:00
parent 793c2c5c3c
commit b89ddd23e5
2 changed files with 25 additions and 3 deletions

View File

@@ -233,7 +233,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
struct GemmConfigPreshuffle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -255,7 +255,7 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPreshuffle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -276,6 +276,28 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshuffle_3 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

View File

@@ -282,7 +282,7 @@ int main(int argc, char* argv[])
{
try
{
return !run_gemm_example<GemmConfigPreshufle_2>(argc, argv);
return !run_gemm_example<GemmConfigPreshuffle_3>(argc, argv);
}
catch(const std::runtime_error& e)
{