feat(gemm_wp): add two new configs for gemm weight preshuffle in gemm_utils.h (#2690)

* feat(gemm_wp): add two new configs for wp

* delete the unnecessary files

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Aviral Goel
2025-08-15 18:00:25 -04:00
committed by GitHub
parent 1c2078066b
commit c06e8b4a66
3 changed files with 15 additions and 28 deletions

View File

@@ -12,6 +12,8 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_universal -j
# The weight preshuffle pipeline on the gemm calculation
make tile_example_gemm_weight_preshuffle -j
```
This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`

View File

@@ -34,21 +34,6 @@ constexpr ck_tile::index_t get_k_warp_tile()
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
struct GemmConfigBase
{
@@ -232,11 +217,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshuffle_1 : public GemmConfigBase
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -244,17 +229,17 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshuffle_2 : public GemmConfigBase
struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -266,7 +251,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -470,7 +455,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1");
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -141,7 +141,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
}
if(s.flush_cache_)
{
@@ -280,7 +280,7 @@ int main(int argc, char* argv[])
try
{
return !run_gemm_example<GemmConfigPreshuffle_2>(arg_parser);
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
}
catch(const std::runtime_error& e)
{