mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Grouped Gemm with Fixed K and N with SplitK (#818)
* move all arguments into device
* add b2c_tile_map
* add examples
* add SetDeviceKernelArgs
* dedicated fixed_nk solution
* init client api
* add grouped_gemm_bias example
* add a instance
* add instances
* formatting
* fixed cmake
* Update EnableCompilerWarnings.cmake
* Update cmake-ck-dev.sh
* clean; fixed comments
* fixed comment
* add instances for fp32 output
* add instances for fp32 output
* add fp32 out client example
* fixed CI
* init commit for kbatch
* add splitk gridwise
* format
* fixed
* clean deviceop
* clean code
* finish splitk
* fixed instances
* change m_loops to tile_loops
* add setkbatch
* clean code
* add splitK+bias
* add instances
* opt mk_nk instances
* clean examples
* fixed CI
* remove zero
* finished non-zero
* clean
* clean code
* optimized global_barrier
* fixed ci
* fixed CI
* removed AddBias
* format
* fixed CI
* fixed CI
* move 20_grouped_gemm to 21_grouped_gemm
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
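The headline feature is split-K for the grouped GEMM: each group's K dimension is divided into k_batch chunks (see the setkbatch and global_barrier items above), each chunk computes a partial product, and the partials are accumulated into the output. Below is a minimal standalone sketch of the split-K decomposition itself, not of the CK kernel; the sizes and k_batch value are made up for illustration.

#include <cstdio>
#include <vector>

// Schematic of the split-K idea only (not the CK implementation): the K
// dimension is split into k_batch chunks, each chunk produces a partial C,
// and the partial results are accumulated into the final output.
int main()
{
    const int M = 2, N = 2, K = 8, k_batch = 4, k_per_batch = K / k_batch;

    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);

    for(int kb = 0; kb < k_batch; ++kb) // in the kernel each kb maps to its own workgroups
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.0f;
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    partial += A[m * K + k] * B[k * N + n];
                C[m * N + n] += partial; // accumulate the partial result
            }

    std::printf("C[0][0] = %.1f (expected %.1f)\n", C[0], 2.0f * K);
    return 0;
}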
@@ -587,7 +587,8 @@ struct OffsettedBlockToCTileMap
 {
     using underlying_type = UnderlyingBlockToCTileMap;
 
-    OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                 index_t block_start)
     {
         block_to_ctile_map_ = block_to_ctile_map;
         block_start_        = block_start;
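For context, OffsettedBlockToCTileMap is what lets a single grouped launch cover many GEMMs: each group owns a contiguous range of workgroup ids, and the stored block_start shifts a global block id into that group's local tile space before the underlying map turns it into an (m, n) output tile. Marking the constructor __host__ __device__ simply makes the map constructible from device code as well as from the host. The following standalone illustration of the indexing idea uses a hypothetical RegularTileMap stand-in; the exact offset arithmetic in CK may differ.

#include <array>
#include <cstdio>

// Hypothetical stand-in for the underlying block-to-C-tile map: row-major
// tiles over an n_blocks_m x n_blocks_n grid.
struct RegularTileMap
{
    int n_blocks_m;
    int n_blocks_n;

    std::array<int, 2> CalculateBottomIndex(int local_block_id) const
    {
        return {local_block_id / n_blocks_n, local_block_id % n_blocks_n};
    }
};

// Mirrors the role of OffsettedBlockToCTileMap: shift the global block id by
// the group's starting block before delegating to the underlying map.
struct OffsettedTileMap
{
    RegularTileMap underlying;
    int block_start;

    std::array<int, 2> CalculateBottomIndex(int global_block_id) const
    {
        return underlying.CalculateBottomIndex(global_block_id - block_start);
    }
};

int main()
{
    // Two groups, each tiled 2 x 3; group 1's blocks start after group 0's six.
    OffsettedTileMap group1{{2, 3}, /*block_start=*/6};

    for(int bid = 6; bid < 12; ++bid)
    {
        auto mn = group1.CalculateBottomIndex(bid);
        std::printf("block %d -> tile (%d, %d)\n", bid, mn[0], mn[1]);
    }
    return 0;
}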
@@ -15,6 +15,9 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+
 namespace ck {
 
 // GEMM:
@@ -74,6 +77,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
 
+    using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -330,6 +335,94 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
 
     using DsGridPointer = decltype(MakeDsGridPointer());
 
+    template <typename ALayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+    }
+
+    template <typename BLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+    }
+
+    template <typename ELayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+        const auto e_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
+            }
+        }();
+
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+    }
+
+    template <typename DsLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
+                             const std::array<index_t, NumDTensor>& NRaws,
+                             const std::array<index_t, NumDTensor>& DsStride)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
+            },
+            Number<NumDTensor>{});
+    }
+
+    __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
+
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
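The GemmSpec template parameter added above threads through to MatrixPadder, which (for the padding specializations) rounds each raw dimension up to a multiple of the corresponding per-block tile size so that arbitrary per-group M/N/K still cover whole tiles. A quick standalone illustration of that rounding follows; the tile sizes are examples, not taken from any particular CK instance.

#include <cstdio>

// Round a raw GEMM dimension up to the next multiple of the block tile size,
// as the padding GemmSpecializations do via MatrixPadder. Tile sizes below
// are illustrative only.
int pad_up(int raw, int per_block) { return ((raw + per_block - 1) / per_block) * per_block; }

int main()
{
    const int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;
    std::printf("M 1000 -> %d\n", pad_up(1000, MPerBlock)); // 1024
    std::printf("N  500 -> %d\n", pad_up(500, NPerBlock));  //  512
    std::printf("K  100 -> %d\n", pad_up(100, KPerBlock));  //  128
    return 0;
}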
@@ -758,6 +851,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             });
         }
     }
+
+    template <bool HasMainKBlockLoop,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+              typename Block2ETileMap>
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
+                               DsGridPointer p_ds_grid,
+                               void* __restrict__ p_e_grid_,
+                               void* __restrict__ p_shared,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const std::array<index_t, NumDTensor> StrideDs,
+                               const index_t StrideE,
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
+        // tensor descriptors for problem definiton
+        const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
+        const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+
+        const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop>(p_a_grid,
+                               p_b_grid,
+                               p_ds_grid,
+                               p_e_grid,
+                               p_shared,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op,
+                               a_grid_desc_ak0_m_ak1,
+                               b_grid_desc_bk0_n_bk1,
+                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               block_2_etile_map);
+    }
 };
 
 } // namespace ck
File diff suppressed because it is too large