mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add MFMA M16N16K16 and M16N16K32 methods
these two methods are default off
This commit is contained in:
@@ -6,6 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#define mfma_m32_n32_k16 1
|
||||
#define mfma_m32_n32_k8 0
|
||||
#define mfma_m16_n16_k16 0
|
||||
#define mfma_m16_n16_k32 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
@@ -15,8 +20,9 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if 1
|
||||
#pragma message ("mfma k16")
|
||||
#if mfma_m32_n32_k8
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
>>>>>>> Stashed changes
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -29,8 +35,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
#pragma message ("mfma k8")
|
||||
#elif mfma_m32_n32_k16
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -43,6 +49,35 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
#elif mfma_m16_n16_k16
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m16_n16_k32
|
||||
#pragma message("mfma m16 n16 k32")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
|
||||
@@ -42,9 +42,9 @@ int main(int argc, char* argv[])
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
K = std::stoi(argv[3]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
@@ -114,7 +114,7 @@ int main(int argc, char* argv[])
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
gemm_kernel{},
|
||||
kGridSize,
|
||||
|
||||
Reference in New Issue
Block a user