diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 6e9e866548..7e4649d969 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -54,11 +54,6 @@ bool run_test() } int main(int, char*[]) { - int deviceCount; - std::cout << hipGetDeviceCount(&deviceCount) << std::endl; - std::cout << deviceCount << std::endl; - std::cout << hipSetDevice(2) << std::endl; - bool pass = true; // clang-format off // |SrcType |DstType |GPUAccType |CPUAccType |AccNum @@ -67,7 +62,9 @@ int main(int, char*[]) pass &= run_test(); pass &= run_test(); pass &= run_test(); - // pass &= run_test(); +#if defined(CK_USE_WMMA_FP8) + pass &= run_test(); +#endif // clang-format on std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 179cf5647b..25ed6709e8 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -98,8 +98,6 @@ builtin_wmma_naive_selector __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) { - printf("dev matmul cicc\n"); - __shared__ src_t p_shared[16 * 16 * 2]; const int lIdx = threadIdx.x; // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and @@ -199,8 +197,6 @@ __global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c) { const int lIdx = threadIdx.x; - printf("dev matmul_swizzle_a cicc\n"); - using src_vec = typename vector_type::type; src_vec a_frag = {}; src_vec b_frag = {}; @@ -377,54 +373,33 @@ struct TestWmma ck::wmma_op_util::RunHostGEMM( a, b, c_host, a_element_op, b_element_op, c_element_op); - // Act - bool is_supported = (ck::is_gfx11_supported() || ck::is_gfx12_supported()) && - ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + // Unsupported types should be filtered out before calling test operator. + bool res = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); - if(is_supported) + if(std::is_same::value) { - // Assert - bool res = false; - if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - // 0.5 Pixel Error Tolerance is introduced by Accumulator difference. - // BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float. - res = ck::utils::check_err( - c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else - { - std::cout << "UNSUPPORTED CDataType" << std::endl; - } - - return res; + // 0.5 Pixel Error Tolerance is introduced by Accumulator difference. + // BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float. + res = ck::utils::check_err( + c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) + { + // Run with default error thresholds. + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } else { - std::cout << "UNSUPPORTED hardware. Skipping test." << std::endl; - return true; + return false; } + + return res; } };