mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
[rocm-libraries] ROCm/rocm-libraries#4299 (commit 668cd49)
173 implement device grouped gemm fixed nk for rdna4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk instance library using for WMMA. The implementation is based on the existing DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level structure, but replaces the XDL kernel with a WMMA-based one. It uses the GridwiseGemm_wmma_cshuffle_v3 kernel. At this stage, the focus is functional correctness and compatibility, not performance tuning. ## Technical Details - Device struct for grouped gemm fixed NK - Example code for the WMMA version - Unit tests for both new wmma implementation and the reference XDL code (previously missing) - Generic ck profiler interface with the purpose of calling unit tests. ## Checklist Please put an into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [x] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run on all changed files - [x] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c5ce5eee5b
commit
7b97e197ef
@@ -46,7 +46,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = true;
|
||||
bool pass = true;
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -54,11 +55,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout);
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -74,8 +75,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<CDataType>> c_m_n_host_results;
|
||||
std::vector<Tensor<CDataType>> c_m_n_device_results;
|
||||
int sum_of_m = 0;
|
||||
|
||||
double max_abs_in_val = 0.f;
|
||||
int sum_of_m = 0;
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += Ms[i];
|
||||
@@ -95,17 +96,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
|
||||
}
|
||||
std::size_t num_thread = 1;
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
|
||||
max_abs_in_val = 5.f;
|
||||
break;
|
||||
default:
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
ck::utils::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n[i]);
|
||||
max_abs_in_val = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,23 +284,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
auto atol = ck::utils::get_absolute_threshold<ComputeDataType, CDataType>(
|
||||
max_abs_in_val, gemm_descs[i].K_);
|
||||
auto rtol = ck::utils::get_relative_threshold<ComputeDataType, CDataType>(
|
||||
gemm_descs[i].K_);
|
||||
|
||||
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i],
|
||||
"Error: Incorrect results!",
|
||||
0.06);
|
||||
}
|
||||
else
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i]);
|
||||
}
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i],
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
@@ -315,7 +312,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
@@ -355,7 +352,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
std::cout << "Instance: " << gemm_name
|
||||
<< ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user