[rocm-libraries] ROCm/rocm-libraries#4299 (commit 668cd49)

173 implement device grouped gemm fixed nk for rdna4
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk
instance library using WMMA.

The implementation is based on the existing
DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level
structure, but replaces the XDL kernel with a WMMA-based one. It uses
the GridwiseGemm_wmma_cshuffle_v3 kernel.

At this stage, the focus is functional correctness and compatibility,
not performance tuning.

## Technical Details

- Device struct for grouped gemm fixed NK
- Example code for the WMMA version
- Unit tests for both new wmma implementation and the reference XDL code
(previously missing)
- Generic ck profiler interface with the purpose of calling unit tests.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
to understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [x] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run the code formatter (clang-format) on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Márton Bidlek
2026-02-19 08:13:46 +00:00
committed by assistant-librarian[bot]
parent c5ce5eee5b
commit 7b97e197ef
32 changed files with 2819 additions and 163 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -775,8 +775,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
c_thread_buf,
block_work_idx[I0],
block_work_idx[I1],
block_work_idx[I2],
p_shared,
p_ds_grid,
p_e_grid,

View File

@@ -344,6 +344,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
return 2;
@@ -627,8 +629,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumATensor>& StrideAs,
const index_t AK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding;
@@ -696,8 +697,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumBTensor>& StrideBs,
const index_t BK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding;
@@ -794,7 +794,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// TODO: Investigate why this path is not used in the original
// gridwise_gemm_xdl_cshuffle_v3.hpp
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
@@ -1033,6 +1032,49 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
/// Validity check for a single GEMM problem described by raw sizes/strides.
///
/// Builds the padded A/B/E grid descriptors the kernel would use and rejects
/// the problem if any of them addresses 2 GiB or more of memory (the grid
/// descriptors use 32-bit signed offsets internally).
///
/// @param M, N, K     GEMM problem dimensions.
/// @param StrideA     Leading-dimension stride of A.
/// @param StrideB     Leading-dimension stride of B.
/// @param StrideDs    Strides of the D tensors (unused here; Ds share E's M/N
///                    footprint, so the E check covers them — TODO confirm).
/// @param StrideE     Leading-dimension stride of E.
/// @param KBatch      Split-K factor used for K padding.
/// @return true if all per-tensor address spaces fit below 2 GiB.
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
              const index_t N,
              const index_t K,
              const index_t StrideA,
              const index_t StrideB,
              const std::array<index_t, NumDTensor> StrideDs,
              const index_t StrideE,
              const index_t KBatch)
{
    ignore = StrideDs;

    const auto M_padded = CalculateMPadded(M);
    // BUG FIX: was CalculateMPadded(N) — N must be padded with the
    // N-padding helper, otherwise the descriptor (and the 2 GiB check)
    // is computed with the wrong padded extent whenever the M and N
    // per-block tile sizes differ.
    const auto N_padded = CalculateNPadded(N);
    const auto K_padded = CalculateKPadded(K, KBatch);

    const auto e_grid_desc_m_n =
        MakeDEGridDescriptor_M_N<ELayout>(M, M_padded, N, N_padded, StrideE);

    const index_t AK0 = CalculateAK0Padded(K, KBatch);
    const index_t BK0 = CalculateBK0Padded(K, KBatch);

    const auto a_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
        M, M_padded, K, K_padded, std::array<index_t, 1>{StrideA}, AK0);
    const auto b_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
        K, K_padded, N, N_padded, std::array<index_t, 1>{StrideB}, BK0);

    // 2 GiB limit: descriptors compute offsets in 32-bit signed arithmetic.
    constexpr long_index_t TwoGB = (long_index_t{1} << 31);

    const auto& a_desc = a_grid_desc_ak0_m_ak1.At(I0);
    const auto& b_desc = b_grid_desc_bk0_n_bk1.At(I0);

    // NOTE(review): sizing A/B by sizeof(LDSTypeA/B) assumes the global
    // element type matches the LDS element type — confirm for packed types.
    if(!(a_desc.GetElementSpaceSize() * sizeof(LDSTypeA) <= TwoGB &&
         b_desc.GetElementSpaceSize() * sizeof(LDSTypeB) <= TwoGB &&
         e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
    {
        return false;
    }

    return true;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg,
@@ -1089,9 +1131,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! "
"K_Batch:"
<< karg.KBatch << " " << "K0PerBlock:" << KPerBlock << " "
<< "K1:" << AK1Number << " " << "K:" << karg.K << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}