Fix split-k gemm test (#231)

* properly return error flag; reveals bug in split-k gemm

* fix bug in split k

* update split-k test case

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Anthony Chang
2022-11-30 00:57:26 +08:00
committed by GitHub
parent 0e9c88cecf
commit 236bd148b9

View File

@@ -226,9 +226,8 @@ int main(int argc, char* argv[])
std::vector<gemmArgs> test_cases;
if(argc == 1)
{
test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
// JD: Populate with more and meaningful
return 0;
test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2},
{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}};
}
else if(argc == 9)
{
@@ -253,11 +252,10 @@ int main(int argc, char* argv[])
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return -1;
}
bool error = false;
for(const auto& kinder : test_cases)
{
const auto res = test_gemm(kinder);
if(!res)
return -1;
error |= test_gemm(kinder);
}
return 0;
return error ? 1 : 0;
}