mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
fixed wmma_op test
This commit is contained in:
@@ -54,11 +54,6 @@ bool run_test()
|
||||
}
|
||||
int main(int, char*[])
|
||||
{
|
||||
int deviceCount;
|
||||
std::cout << hipGetDeviceCount(&deviceCount) << std::endl;
|
||||
std::cout << deviceCount << std::endl;
|
||||
std::cout << hipSetDevice(2) << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
// clang-format off
|
||||
// |SrcType |DstType |GPUAccType |CPUAccType |AccNum
|
||||
@@ -67,7 +62,9 @@ int main(int, char*[])
|
||||
pass &= run_test<ck::half_t, ck::half_t, ck::half_t, ck::half_t, 16 >();
|
||||
pass &= run_test<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float, 16 >();
|
||||
pass &= run_test<int8_t, int8_t, int32_t, int32_t, 8 >();
|
||||
// pass &= run_test<ck::f8_t, ck::f8_t, float, float, 8 >();
|
||||
#if defined(CK_USE_WMMA_FP8)
|
||||
pass &= run_test<ck::f8_t, ck::f8_t, float, float, 8 >();
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
|
||||
@@ -98,8 +98,6 @@ builtin_wmma_naive_selector<int4x16_t,
|
||||
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
|
||||
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
||||
{
|
||||
printf("dev matmul cicc\n");
|
||||
|
||||
__shared__ src_t p_shared[16 * 16 * 2];
|
||||
const int lIdx = threadIdx.x;
|
||||
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
|
||||
@@ -199,8 +197,6 @@ __global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c)
|
||||
{
|
||||
const int lIdx = threadIdx.x;
|
||||
|
||||
printf("dev matmul_swizzle_a cicc\n");
|
||||
|
||||
using src_vec = typename vector_type<src_t, 16>::type;
|
||||
src_vec a_frag = {};
|
||||
src_vec b_frag = {};
|
||||
@@ -377,54 +373,33 @@ struct TestWmma
|
||||
ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
|
||||
a, b, c_host, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
// Act
|
||||
bool is_supported = (ck::is_gfx11_supported() || ck::is_gfx12_supported()) &&
|
||||
ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
|
||||
// Unsupported types should be filtered out before calling test operator.
|
||||
bool res = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
|
||||
|
||||
if(is_supported)
|
||||
if(std::is_same<CDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
// Assert
|
||||
bool res = false;
|
||||
if(std::is_same<CDataType, float>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, ck::half_t>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
|
||||
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
|
||||
res = ck::utils::check_err(
|
||||
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, double>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "UNSUPPORTED CDataType" << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
|
||||
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
|
||||
res = ck::utils::check_err(
|
||||
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, float>::value ||
|
||||
std::is_same<CDataType, ck::half_t>::value ||
|
||||
std::is_same<CDataType, int8_t>::value ||
|
||||
std::is_same<CDataType, double>::value ||
|
||||
std::is_same<CDataType, f8_t>::value)
|
||||
{
|
||||
// Run with default error thresholds.
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "UNSUPPORTED hardware. Skipping test." << std::endl;
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user