Fix for Unsupported Input Shapes/Sizes in Stream-K GEMM - BF16/FP16 (#1866)

This commit is contained in:
Muhammed Emin Ozturk
2025-02-18 08:46:47 -08:00
committed by GitHub
parent c287418dcc
commit 92b79ead0a
5 changed files with 85 additions and 16 deletions

View File

@@ -56,6 +56,26 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
exit(1);
}
int M;
int N;
int StrideA;
int StrideB;
// Analyze the unsupported matrix shapes, switch the M and N number
if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0)
{
M = std::stoi(argv[9]);
StrideA = std::stoi(argv[12]);
N = std::stoi(argv[8]);
StrideB = std::stoi(argv[11]);
}
else
{
M = std::stoi(argv[8]);
StrideA = std::stoi(argv[11]);
N = std::stoi(argv[9]);
StrideB = std::stoi(argv[12]);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
@@ -63,12 +83,8 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int Streamk_sel = std::stoi(argv[14]);
const int Grid_size = std::stoi(argv[15]);