merge flatmm pipe v0 from dteng_flatmm_opt

This commit is contained in:
valarLip
2025-07-23 08:44:12 +00:00
parent 6dacf833da
commit 89fa639207
5 changed files with 987 additions and 105 deletions

View File

@@ -86,7 +86,7 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 2;
};
template <typename ADataType>
@@ -167,119 +167,119 @@ struct is_8bit_type
{
};
template <typename DataType>
struct GemmConfig
{
#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// template <typename DataType>
// struct GemmConfig
// {
// #if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
// #elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
// #elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
// #elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
// static constexpr ck_tile::index_t M_Tile = 16;
// static constexpr ck_tile::index_t N_Tile = 64;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
// #elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 8;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
// #else
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
};
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
// #endif
// };
auto create_args(int argc, char* argv[])
{