mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement batched gemm add relu gemm add for rdna4 (#3391)
* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implenentation * wip: many fixes in implementation of batched gemm gemm multiple d * wip: batched gemm gemm multiple d gridwise op compiling, not working yet * fix: incorrect d0 grid indexing in batched gemm gemm multipled * feat: add instances for batched gemm add relu gemm add * chore: configure instance with low vector transfer size for odd sizes * chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense * fix: upate device_batched_gemm_gemm_wmma to work with new gridwise changes * fix: disable odd size tests on XDL archs * chore: removed temporary logging * chore: update some references to C tensor to E tensor * Tentative fix for example template params * Tentative fix for non-multi-D batched gemm gemm device impl. * Tentative fix for xdl example template params * Tentative fix for profiler build on gfx90a * chore: improve device batched gemm gemm multi D comment to include all ops and dimensions * chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply * fix: make the gemm1 data types match what happens in the device op * feat: add d0s/d1s datatypes and layouts to the device op type string * chore: change element-wise op so addition happens in fp32 * chore: add static asserts for gemm0/gemm1 calculated wave sizes * chore: also updated other element-wise ops to use fp32 calculations * chore: log number of supported instances * chore: update instance comment * chore: disable kernel timing in example by default * fix: gemm1 wave size calculation * fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions * chore: remove increased tolerance in batched gemm gemm multiple d example * chore: add comment explaining that verification fails for certain input values * chore: clarify instance comment --------- Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
Tuple<>{}, // p_d0s_grid
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
Tuple<>{}, // p_d1s_grid
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
@@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// DataType Family
|
||||
ADataType,
|
||||
B0DataType,
|
||||
Tuple<>, // Ds0DataType
|
||||
AccDataType, // Acc0DataType
|
||||
B1DataType,
|
||||
Tuple<>, // Ds1DataType
|
||||
AccDataType, // Acc1DataType
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
@@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
Tuple<>, // Ds0GridDesc
|
||||
B1GridDesc,
|
||||
Tuple<>, // Ds1GridDesc
|
||||
CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
@@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B0BlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
B0BlockLdsAddExtraL,
|
||||
1, // CDE0BlockTransferSrcScalarPerVector
|
||||
B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
@@ -369,8 +379,8 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
@@ -405,10 +415,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
@@ -500,7 +510,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<A0DataType, B0DataType, Gemm0MPerXdl, Gemm0NPerXdl>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
|
||||
{
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "wrong! Unsupported tensor layout combination." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -101,6 +101,15 @@ struct GemmGemmPadder
|
||||
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
|
||||
}
|
||||
|
||||
// D0[M, N]
|
||||
template <typename D0Desc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor(
|
||||
d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
|
||||
}
|
||||
|
||||
// B1[Gemm1N, Gemm1K] = B1[O, N]
|
||||
template <typename B1Desc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
|
||||
Reference in New Issue
Block a user