mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement batched gemm bias permute for RDNA4 (#3534)
* feat: test setup for batched contraction (aka batched gemm multiple d e permute) * wip: device struct for WMMA batched contraction multiple d based on new gridwise op * feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases * fix: failure to resolve template parameters when calling new function overload * fix: passing reference type as parameter instead of underlying types * fix: merge error caused duplicate definitions * fix: make sure constness of template and parameters types match * fix: don't compile batched contraction test on unsupported architectures * feat: add example for new wmma implementation, and consolidate example code between platforms * style: return inline instead of with branch * chore: add extra assert on vector memory access sizes * chore: clean up some unused variables * fix: correct tail number calculation, added small cases and extra instances to the test * fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method
This commit is contained in:
@@ -414,22 +414,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public Problem
|
||||
{
|
||||
__host__ Argument() = default;
|
||||
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
__host__ __device__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
|
||||
p_as_grid{},
|
||||
p_bs_grid{},
|
||||
@@ -607,6 +607,67 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Block2CTileMap,
|
||||
EpilogueArgument,
|
||||
BlockMapMBlockIndex,
|
||||
BlockMapNBlockIndex>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
// Overload to pass in custom As/Bs/Ds/E grid descriptors
|
||||
// Used for contraction operations, where tensor transforms are non-trivial
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename AsGridDescriptor_AK0_M_AK1,
|
||||
typename BsGridDescriptor_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename EpilogueArgument,
|
||||
int BlockMapMBlockIndex = 0,
|
||||
int BlockMapNBlockIndex = 1>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
|
||||
const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
@@ -773,9 +834,13 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
__device__ __host__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
{
|
||||
return Block2CTileMap{problem.M, problem.N, 4};
|
||||
return DefaultBlock2CTileMap(problem.M, problem.N);
|
||||
}
|
||||
__device__ __host__ static auto DefaultBlock2CTileMap(const index_t M, const index_t N)
|
||||
{
|
||||
return Block2CTileMap{M, N, 4};
|
||||
}
|
||||
|
||||
// Run method for convolution for bwd_data (grid descriptors are passed as arguments,
|
||||
|
||||
@@ -499,8 +499,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BaseDescriptors_M_K>
|
||||
__host__ __device__ static auto
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs,
|
||||
const index_t M,
|
||||
const index_t MPad,
|
||||
const index_t K,
|
||||
const index_t KPad,
|
||||
@@ -518,10 +520,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
GemmSpec == GemmSpecialization::NKPadding;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]);
|
||||
|
||||
return ATransfer::template MakeGridDescriptor<padM, padK>(
|
||||
base_desc, M, MPad, K, KPad, StrideAs[i], AK0);
|
||||
base_descs[i], M, MPad, K, KPad, StrideAs[i], AK0);
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
@@ -539,8 +539,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return ATransfer::template MakeGridDescriptor<padM, padK>(base_desc, M, M, K, K, 0, AK0);
|
||||
}
|
||||
|
||||
template <typename BaseDescriptors_M_K>
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, const index_t KBatch = 1)
|
||||
{
|
||||
const index_t M = base_descs.At(I0).GetLength(I0);
|
||||
const index_t K = base_descs.At(I0).GetLength(I1);
|
||||
|
||||
const index_t MPad = CalculateMPadded(M);
|
||||
const index_t KPad = CalculateKPadded(K, KBatch);
|
||||
|
||||
const index_t AK0 = CalculateAK0Padded(K, KBatch);
|
||||
|
||||
return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, {}, AK0);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
|
||||
const index_t MPad,
|
||||
const index_t K,
|
||||
const index_t KPad,
|
||||
const std::array<index_t, NumATensor>& StrideAs,
|
||||
const index_t AK0)
|
||||
{
|
||||
const auto base_descs =
|
||||
generate_tuple([&](auto i) { return MakeAGridDescriptor_M_K(M, K, StrideAs[i]); },
|
||||
Number<NumATensor>{});
|
||||
return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, StrideAs, AK0);
|
||||
}
|
||||
|
||||
template <typename BaseDescriptors_N_K>
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs,
|
||||
const index_t K,
|
||||
const index_t KPad,
|
||||
const index_t N,
|
||||
const index_t NPad,
|
||||
@@ -558,9 +589,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
GemmSpec == GemmSpecialization::MKPadding;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]);
|
||||
return BTransfer::template MakeGridDescriptor<padN, padK>(
|
||||
base_desc, N, NPad, K, KPad, StrideBs[i], BK0);
|
||||
base_descs[i], N, NPad, K, KPad, StrideBs[i], BK0);
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
@@ -578,6 +608,36 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return BTransfer::template MakeGridDescriptor<padN, padK>(base_desc, N, N, K, K, 0, BK0);
|
||||
}
|
||||
|
||||
template <typename BaseDescriptors_N_K>
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs, const index_t KBatch = 1)
|
||||
{
|
||||
const index_t N = base_descs.At(I0).GetLength(I0);
|
||||
const index_t K = base_descs.At(I0).GetLength(I1);
|
||||
|
||||
const index_t NPad = CalculateNPadded(N);
|
||||
const index_t KPad = CalculateKPadded(K, KBatch);
|
||||
|
||||
const index_t BK0 = CalculateBK0Padded(K, KBatch);
|
||||
|
||||
return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, {}, BK0);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
|
||||
const index_t KPad,
|
||||
const index_t N,
|
||||
const index_t NPad,
|
||||
const std::array<index_t, NumBTensor>& StrideBs,
|
||||
const index_t BK0)
|
||||
{
|
||||
|
||||
const auto base_descs =
|
||||
generate_tuple([&](auto i) { return MakeBGridDescriptor_N_K(N, K, StrideBs[i]); },
|
||||
Number<NumBTensor>{});
|
||||
return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, StrideBs, BK0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
@@ -681,7 +741,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ __host__ static constexpr auto
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
|
||||
index_t MBlock,
|
||||
index_t NBlock)
|
||||
|
||||
Reference in New Issue
Block a user