mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Device Op GroupedGemmMultipleD + example fp16 (#633)
* Pass shared mem pointer as pointer to void. * Device Op GroupedGEMM Multiple D * Example for grouped gemm multiple d. * Add MI200 to supported archs. --------- Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -238,19 +238,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
|
||||
using CGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
|
||||
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
bool HasDoubleTailKBlockLoop,
|
||||
typename Block2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
void* __restrict__ p_shared_block,
|
||||
const AElementwiseOperation&,
|
||||
const BElementwiseOperation&,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
@@ -399,8 +399,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
|
||||
FloatAB* p_b_block_double =
|
||||
static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
|
||||
|
||||
Reference in New Issue
Block a user