mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
merge flatmm pipe v0 from dteng_flatmm_opt
This commit is contained in:
@@ -86,7 +86,7 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
@@ -167,119 +167,119 @@ struct is_8bit_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmConfig
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
// template <typename DataType>
|
||||
// struct GemmConfig
|
||||
// {
|
||||
// #if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 256;
|
||||
// static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
// #elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
// #elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
// #elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
// #elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
|
||||
// static constexpr ck_tile::index_t M_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Tile = 64;
|
||||
// static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
// #elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 256;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 8;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 8;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
// #elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
// #elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
// #else
|
||||
// static constexpr ck_tile::index_t M_Tile = 128;
|
||||
// static constexpr ck_tile::index_t N_Tile = 256;
|
||||
// static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
// static constexpr ck_tile::index_t M_Warp = 1;
|
||||
// static constexpr ck_tile::index_t N_Warp = 4;
|
||||
// static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#endif
|
||||
};
|
||||
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
// #endif
|
||||
// };
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user