mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feat(gemm_wp): add two new configs for gemm weight preshuffle in gemm_utils.h (#2690)
* feat(gemm_wp): add two new configs for wp * delete the unnecessary files --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -12,6 +12,8 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_gemm_basic -j
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
make tile_example_gemm_universal -j
|
||||
# The weight preshuffle pipeline on the gemm calculation
|
||||
make tile_example_gemm_weight_preshuffle -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`
|
||||
|
||||
|
||||
@@ -34,21 +34,6 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
@@ -232,11 +217,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffle_1 : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -244,17 +229,17 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffle_2 : public GemmConfigBase
|
||||
struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -266,7 +251,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -470,7 +455,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1", "rotating count, defaults to 1");
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -141,7 +141,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
@@ -280,7 +280,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffle_2>(arg_parser);
|
||||
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user