Add MFMA M16N16K16 and M16N16K32 methods

these two methods are default off
This commit is contained in:
bobofang
2025-03-28 16:30:42 +00:00
committed by Philip Maybank
parent e866f814f9
commit 127e742e96
2 changed files with 43 additions and 8 deletions

View File

@@ -6,6 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#define mfma_m32_n32_k16 1
#define mfma_m32_n32_k8 0
#define mfma_m16_n16_k16 0
#define mfma_m16_n16_k32 0
namespace ck_tile {
// Default policy for BlockGemmASmemBSmemCReg
@@ -15,8 +20,9 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if 1
#pragma message ("mfma k16")
#if mfma_m32_n32_k8
#pragma message ("mfma m32 n32 k16")
>>>>>>> Stashed changes
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -29,8 +35,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#else
#pragma message ("mfma k8")
#elif mfma_m32_n32_k16
#pragma message ("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -43,6 +49,35 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#elif mfma_m16_n16_k16
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
#elif mfma_m16_n16_k32
#pragma message("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
#endif
else
{

View File

@@ -42,9 +42,9 @@ int main(int argc, char* argv[])
if(argc == 5)
{
verification = std::stoi(argv[1]);
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
M = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
}
const ck_tile::index_t Lda = K;
@@ -114,7 +114,7 @@ int main(int argc, char* argv[])
kGemmKPerBlock>;
float ave_time =
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
gemm_kernel{},
kGridSize,