mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Add a gpu gemm reference kernel (#1528)
* Add a gpu gemm reference kernel * Switch to gpu reference in gemm examples * Remove redundant arguments * Update all related examples * Update more examples * Try less threads per block * Try even less threads per block * Add support for all matrix layouts * Increase block size * Clean up * Remove hardcoded strides * Clean up * Try a column-major case * Revert back to row-major * Run both CPU and GPU veriffication --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -173,6 +173,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -193,6 +194,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) *
|
||||
c_m_n_device_ref_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
@@ -325,14 +328,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
// CPU verification
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
std::cout << "Running verification on CPU." << std::endl;
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
@@ -346,15 +353,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
|
||||
// GPU verification
|
||||
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
|
||||
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
|
||||
|
||||
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
std::cout << "Running verification on GPU." << std::endl;
|
||||
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{});
|
||||
|
||||
c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
return true;
|
||||
return !pass;
|
||||
}
|
||||
|
||||
bool run_gemm_example(int argc, char* argv[])
|
||||
|
||||
Reference in New Issue
Block a user