ctest of batched_gemm returns 0 or 1 (#149)

* ctest of batched_gemm returns 0 or 1

* minor change

[ROCm/composable_kernel commit: 313bbea588]
This commit is contained in:
Jianfeng Yan
2022-03-24 19:38:02 -05:00
committed by GitHub
parent 73a1a1f0d8
commit e88e4d8b5b
2 changed files with 11 additions and 13 deletions

View File

@@ -109,13 +109,13 @@ bool TestBatchedGemm(const std::size_t batch_count, DeviceBatchedGemmPtr& gemmPt
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
// Assert
// bool res = test::check_err(
// bool pass = test::check_err(
// c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
bool res = check_error(c_device, c_host) < 0.007815f;
bool pass = check_error(c_device, c_host) < 0.007815f;
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
std::cout << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return res;
return pass;
}
} // namespace
@@ -125,13 +125,15 @@ int main()
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(batched_gemm_ptrs);
bool res = true;
bool pass = true;
const std::size_t batch_count = 4;
for(auto& gemmPtr : batched_gemm_ptrs)
{
res &= TestBatchedGemm(batch_count, gemmPtr);
pass &= TestBatchedGemm(batch_count, gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}

View File

@@ -14,12 +14,8 @@ int main(int argc, char** argv)
(void)argc;
(void)argv;
{
traverse_using_space_filling_curve();
auto err = hipDeviceSynchronize();
(void)err;
assert(err == hipSuccess);
}
traverse_using_space_filling_curve();
return 0;
}