mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4299 (commit 668cd49)
173 implement device grouped gemm fixed nk for rdna4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk instance library using WMMA. The implementation is based on the existing DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level structure, but replaces the XDL kernel with a WMMA-based one. It uses the GridwiseGemm_wmma_cshuffle_v3 kernel. At this stage, the focus is functional correctness and compatibility, not performance tuning. ## Technical Details - Device struct for grouped gemm fixed NK - Example code for the WMMA version - Unit tests for both the new WMMA implementation and the reference XDL code (previously missing) - Generic ck profiler interface with the purpose of calling unit tests. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers to understand the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [x] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [x] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c5ce5eee5b
commit
7b97e197ef
File diff suppressed because it is too large
Load Diff
@@ -775,8 +775,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_thread_buf,
|
||||
block_work_idx[I0],
|
||||
block_work_idx[I1],
|
||||
block_work_idx[I2],
|
||||
p_shared,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
|
||||
@@ -344,6 +344,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
|
||||
return 2;
|
||||
@@ -627,8 +629,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const std::array<index_t, NumATensor>& StrideAs,
|
||||
const index_t AK0)
|
||||
{
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
|
||||
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding;
|
||||
@@ -696,8 +697,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const std::array<index_t, NumBTensor>& StrideBs,
|
||||
const index_t BK0)
|
||||
{
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
|
||||
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding;
|
||||
@@ -794,7 +794,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// TODO: Investigate why this path is not used in the original
|
||||
// gridwise_gemm_xdl_cshuffle_v3.hpp
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
@@ -1033,6 +1032,49 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validity check for one GEMM problem given as raw sizes and strides.
///
/// Builds the A, B and E grid descriptors for the problem and rejects it when the
/// element-space size of any of them, multiplied by the size of its element type,
/// exceeds 2^31.
///
/// @param M        GEMM M dimension.
/// @param N        GEMM N dimension.
/// @param K        GEMM K dimension.
/// @param StrideA  Stride of the A tensor.
/// @param StrideB  Stride of the B tensor.
/// @param StrideDs Strides of the D tensors; currently not used by this check.
/// @param StrideE  Stride of the E tensor.
/// @param KBatch   Split-K batch count, forwarded to the K padding helpers.
/// @return true when A, B and E all fit within the 2^31 limit, false otherwise.
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
              const index_t N,
              const index_t K,
              const index_t StrideA,
              const index_t StrideB,
              const std::array<index_t, NumDTensor> StrideDs,
              const index_t StrideE,
              const index_t KBatch)
{
    // The D tensors do not take part in the size check.
    ignore = StrideDs;

    const auto M_padded = CalculateMPadded(M);
    // N has to be padded with the N-dimension helper: using CalculateMPadded here
    // gives a wrong N_padded whenever the M and N tile sizes differ.
    const auto N_padded = CalculateNPadded(N);
    const auto K_padded = CalculateKPadded(K, KBatch);

    const auto e_grid_desc_m_n =
        MakeDEGridDescriptor_M_N<ELayout>(M, M_padded, N, N_padded, StrideE);

    const index_t AK0 = CalculateAK0Padded(K, KBatch);
    const index_t BK0 = CalculateBK0Padded(K, KBatch);

    // NOTE(review): a single-element stride array is passed, so this overload
    // presumably expects NumATensor == NumBTensor == 1 -- confirm against callers.
    const auto a_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
        M, M_padded, K, K_padded, std::array<index_t, 1>{StrideA}, AK0);

    const auto b_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
        K, K_padded, N, N_padded, std::array<index_t, 1>{StrideB}, BK0);

    constexpr long_index_t TwoGB = (long_index_t{1} << 31);

    // Only the first A and B descriptors are inspected.
    const auto& a_desc = a_grid_desc_ak0_m_ak1.At(I0);
    const auto& b_desc = b_grid_desc_bk0_n_bk1.At(I0);

    if(!(a_desc.GetElementSpaceSize() * sizeof(LDSTypeA) <= TwoGB &&
         b_desc.GetElementSpaceSize() * sizeof(LDSTypeB) <= TwoGB &&
         e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
    {
        return false;
    }

    return true;
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Argument>
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg,
|
||||
@@ -1089,9 +1131,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! "
|
||||
"K_Batch:"
|
||||
<< karg.KBatch << " " << "K0PerBlock:" << KPerBlock << " "
|
||||
<< "K1:" << AK1Number << " " << "K:" << karg.K << " " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user