mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Refine fp8 support in flatmm (#2239)
* [CK_TILE] Refine fp8 in flatmm 1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr 2. Add an additional const check to avoid build error in HotLoopScheduler 3. Refine shuffleb to support both tile 32x32 and 16x16 4. Support command option -init 5. Move Gemm warp defintion to a separate struct * fix clang format * fix clang format * keep default bhavior unchanged (warp tile = 16x16) * fix tile engine build error * fix a typo in codegen_utils.py * address review comments * address review comments --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -31,7 +31,63 @@
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
// GEMM config with 32x132 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
@@ -103,47 +159,10 @@ struct DataTypeTraits<ck_tile::half_t>
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_8bit_type
|
||||
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct FlatmmConfig
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 8;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
|
||||
#endif
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -163,8 +182,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
@@ -174,6 +196,7 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
|
||||
Reference in New Issue
Block a user