mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix a bug that K_Warp_Tile not match in two files.
This commit is contained in:
@@ -27,46 +27,18 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
#if defined(USING_MFMA_16x16x128)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig<BDataType>::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig<BDataType>::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig<BDataType>::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#endif
|
||||
// This part comes from the Codegen
|
||||
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig<BDataType>::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig<BDataType>::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig<BDataType>::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 8;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig<BDataType>::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig<BDataType>::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig<BDataType>::K_Warp_Tile;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
|
||||
@@ -31,21 +31,6 @@
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
};
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
@@ -124,6 +109,61 @@ struct is_8bit_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmConfig
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x128)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#elif defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<DataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<DataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<DataType>::value ? 64 : 16;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 8;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<DataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<DataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<DataType>::value ? 32 : 16;
|
||||
#else
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
Reference in New Issue
Block a user