fixed wmma_op test

This commit is contained in:
Zoltan Lakatos
2025-06-19 13:01:51 +00:00
parent 5e454276e3
commit a8dec7a4da
2 changed files with 24 additions and 52 deletions

View File

@@ -54,11 +54,6 @@ bool run_test()
}
int main(int, char*[])
{
int deviceCount;
std::cout << hipGetDeviceCount(&deviceCount) << std::endl;
std::cout << deviceCount << std::endl;
std::cout << hipSetDevice(2) << std::endl;
bool pass = true;
// clang-format off
// |SrcType |DstType |GPUAccType |CPUAccType |AccNum
@@ -67,7 +62,9 @@ int main(int, char*[])
pass &= run_test<ck::half_t, ck::half_t, ck::half_t, ck::half_t, 16 >();
pass &= run_test<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float, 16 >();
pass &= run_test<int8_t, int8_t, int32_t, int32_t, 8 >();
// pass &= run_test<ck::f8_t, ck::f8_t, float, float, 8 >();
#if defined(CK_USE_WMMA_FP8)
pass &= run_test<ck::f8_t, ck::f8_t, float, float, 8 >();
#endif
// clang-format on
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;

View File

@@ -98,8 +98,6 @@ builtin_wmma_naive_selector<int4x16_t,
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{
printf("dev matmul cicc\n");
__shared__ src_t p_shared[16 * 16 * 2];
const int lIdx = threadIdx.x;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
@@ -199,8 +197,6 @@ __global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c)
{
const int lIdx = threadIdx.x;
printf("dev matmul_swizzle_a cicc\n");
using src_vec = typename vector_type<src_t, 16>::type;
src_vec a_frag = {};
src_vec b_frag = {};
@@ -377,54 +373,33 @@ struct TestWmma
ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act
bool is_supported = (ck::is_gfx11_supported() || ck::is_gfx12_supported()) &&
ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
// Unsupported types should be filtered out before calling test operator.
bool res = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
if(is_supported)
if(std::is_same<CDataType, ck::bhalf_t>::value)
{
// Assert
bool res = false;
if(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::bhalf_t>::value)
{
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
res = ck::utils::check_err(
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, int8_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else
{
std::cout << "UNSUPPORTED CDataType" << std::endl;
}
return res;
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
res = ck::utils::check_err(
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, float>::value ||
std::is_same<CDataType, ck::half_t>::value ||
std::is_same<CDataType, int8_t>::value ||
std::is_same<CDataType, double>::value ||
std::is_same<CDataType, f8_t>::value)
{
// Run with default error thresholds.
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else
{
std::cout << "UNSUPPORTED hardware. Skipping test." << std::endl;
return true;
return false;
}
return res;
}
};