mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
draft for gfx950,32*32*64mfma test
This commit is contained in:
@@ -16,6 +16,7 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
ck_tile::index_t GemmCfgVer,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
@@ -29,32 +30,19 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
constexpr int kBlockPerCu = 2;
|
||||
|
||||
// This part comes from the Codegen
|
||||
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
static_assert(sizeof(ADataType) == 2 || sizeof(ADataType) == 1);
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig<BDataType, GemmCfgVer>::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig<BDataType, GemmCfgVer>::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig<BDataType, GemmCfgVer>::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig<BDataType, GemmCfgVer>::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig<BDataType, GemmCfgVer>::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig<BDataType, GemmCfgVer>::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig<BDataType, GemmCfgVer>::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig<BDataType, GemmCfgVer>::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig<BDataType, GemmCfgVer>::K_Warp_Tile;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 8;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
using CodegenFlatmmShape =
|
||||
ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
@@ -134,6 +122,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
#if 0
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -176,5 +165,5 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
#endif
|
||||
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
|
||||
|
||||
Reference in New Issue
Block a user