Implement batched gemm bias permute for RDNA4 (#3534)

* feat: test setup for batched contraction (aka batched gemm multiple d e permute)

* wip: device struct for WMMA batched contraction multiple d based on new gridwise op

* feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases

* fix: failure to resolve template parameters when calling new function overload

* fix: passing reference type as parameter instead of underlying types

* fix: merge error caused duplicate definitions

* fix: make sure constness of template and parameters types match

* fix: don't compile batched contraction test on unsupported architectures

* feat: add example for new wmma implementation, and consolidate example code between platforms

* style: return inline instead of with branch

* chore: add extra assert on vector memory access sizes

* chore: clean up some unused variables

* fix: correct tail number calculation, added small cases and extra instances to the test

* fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method
This commit is contained in:
Erwin Terpstra
2026-01-17 08:30:27 +01:00
committed by GitHub
parent f9104ef9b3
commit fe40a5d139
18 changed files with 2475 additions and 1009 deletions

View File

@@ -414,22 +414,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument() = default;
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
__host__ __device__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
p_as_grid{},
p_bs_grid{},
@@ -607,6 +607,67 @@ struct GridwiseGemm_wmma_cshuffle_v3
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Block2CTileMap,
EpilogueArgument,
BlockMapMBlockIndex,
BlockMapNBlockIndex>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
a_element_op,
b_element_op,
cde_element_op,
epilogue_args,
A_k_id,
B_k_id);
}
// Overload to pass in custom As/Bs/Ds/E grid descriptors
// Used for contraction operations, where tensor transforms are non-trivial
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename AsGridDescriptor_AK0_M_AK1,
typename BsGridDescriptor_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args,
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -773,9 +834,13 @@ struct GridwiseGemm_wmma_cshuffle_v3
B_k_id);
}
// Build the default block-to-C-tile map from the gemm problem sizes.
// (Fix: the diff/merge residue had the pre-change signature and return
// statement interleaved with the new code, yielding duplicate definitions;
// only the intended post-change overloads are kept.)
__device__ __host__ static auto DefaultBlock2CTileMap(const Problem& problem)
{
    return DefaultBlock2CTileMap(problem.M, problem.N);
}

// Overload taking the raw M/N extents directly, so callers without a full
// Problem (e.g. contraction wrappers) can construct the map. The constant 4
// is forwarded to the Block2CTileMap constructor (presumably a tile-grouping
// / swizzle factor — NOTE(review): confirm against Block2CTileMap's ctor).
__device__ __host__ static auto DefaultBlock2CTileMap(const index_t M, const index_t N)
{
    return Block2CTileMap{M, N, 4};
}
// Run method for convolution for bwd_data (grid descriptors are passed as arguments,

View File

@@ -499,8 +499,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
template <typename BaseDescriptors_M_K>
__host__ __device__ static auto
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs,
const index_t M,
const index_t MPad,
const index_t K,
const index_t KPad,
@@ -518,10 +520,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
GemmSpec == GemmSpecialization::NKPadding;
return generate_tuple(
[&](auto i) {
const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]);
return ATransfer::template MakeGridDescriptor<padM, padK>(
base_desc, M, MPad, K, KPad, StrideAs[i], AK0);
base_descs[i], M, MPad, K, KPad, StrideAs[i], AK0);
},
Number<NumATensor>{});
}
@@ -539,8 +539,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return ATransfer::template MakeGridDescriptor<padM, padK>(base_desc, M, M, K, K, 0, AK0);
}
// Convenience overload: derive all gemm sizes from already-built per-tensor
// A base descriptors (M x K), then forward to the size-explicit overload.
// Sizes are read from the first descriptor — assumes all A descriptors share
// the same M x K extents (TODO confirm at call sites). Strides are not needed
// here since the descriptors are already constructed, hence the empty array.
// (Fix: removed a stray leftover declaration line
// `MakeBsGridDescriptor_BK0_N_BK1(const index_t K,` — diff/merge residue
// between the return type and the real signature that broke the syntax.)
template <typename BaseDescriptors_M_K>
__host__ __device__ static auto
MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, const index_t KBatch = 1)
{
    const index_t M    = base_descs.At(I0).GetLength(I0);
    const index_t K    = base_descs.At(I0).GetLength(I1);
    const index_t MPad = CalculateMPadded(M);
    const index_t KPad = CalculateKPadded(K, KBatch);
    const index_t AK0  = CalculateAK0Padded(K, KBatch);

    return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, {}, AK0);
}
// Size/stride-based overload: construct the default per-tensor A base
// descriptors (M x K) from the raw extents and strides, then delegate to the
// descriptor-based overload that applies the ATransfer grid transform.
__host__ __device__ static auto
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
                               const index_t MPad,
                               const index_t K,
                               const index_t KPad,
                               const std::array<index_t, NumATensor>& StrideAs,
                               const index_t AK0)
{
    // One base descriptor per A tensor, each built from its own stride.
    return MakeAsGridDescriptor_AK0_M_AK1(
        generate_tuple([&](auto i) { return MakeAGridDescriptor_M_K(M, K, StrideAs[i]); },
                       Number<NumATensor>{}),
        M,
        MPad,
        K,
        KPad,
        StrideAs,
        AK0);
}
template <typename BaseDescriptors_N_K>
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs,
const index_t K,
const index_t KPad,
const index_t N,
const index_t NPad,
@@ -558,9 +589,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
GemmSpec == GemmSpecialization::MKPadding;
return generate_tuple(
[&](auto i) {
const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]);
return BTransfer::template MakeGridDescriptor<padN, padK>(
base_desc, N, NPad, K, KPad, StrideBs[i], BK0);
base_descs[i], N, NPad, K, KPad, StrideBs[i], BK0);
},
Number<NumBTensor>{});
}
@@ -578,6 +608,36 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return BTransfer::template MakeGridDescriptor<padN, padK>(base_desc, N, N, K, K, 0, BK0);
}
// Convenience overload: derive all gemm sizes from already-built per-tensor
// B base descriptors (N x K), then forward to the size-explicit overload.
// Extents are read from the first descriptor — assumes every B descriptor
// shares the same N x K shape (TODO confirm at call sites). The stride array
// is left empty since the descriptors already encode the layout.
template <typename BaseDescriptors_N_K>
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs, const index_t KBatch = 1)
{
    const index_t n_extent = base_descs.At(I0).GetLength(I0);
    const index_t k_extent = base_descs.At(I0).GetLength(I1);

    return MakeBsGridDescriptor_BK0_N_BK1(base_descs,
                                          k_extent,
                                          CalculateKPadded(k_extent, KBatch),
                                          n_extent,
                                          CalculateNPadded(n_extent),
                                          {},
                                          CalculateBK0Padded(k_extent, KBatch));
}
// Size/stride-based overload: construct the default per-tensor B base
// descriptors (N x K) from the raw extents and strides, then delegate to the
// descriptor-based overload that applies the BTransfer grid transform.
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
                               const index_t KPad,
                               const index_t N,
                               const index_t NPad,
                               const std::array<index_t, NumBTensor>& StrideBs,
                               const index_t BK0)
{
    // One base descriptor per B tensor, each built from its own stride.
    return MakeBsGridDescriptor_BK0_N_BK1(
        generate_tuple([&](auto i) { return MakeBGridDescriptor_N_K(N, K, StrideBs[i]); },
                       Number<NumBTensor>{}),
        K,
        KPad,
        N,
        NPad,
        StrideBs,
        BK0);
}
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
{
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
@@ -681,7 +741,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
template <typename DsGridDesc>
__device__ __host__ static constexpr auto
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)