mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Grouped Gemm + SplitK + simplified Kernel Args (#669)
* simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * B2C with 3D grid for KSplit * Remove unused code. * Use default B2C (3D grid) in grid gemm v2r4r2. * Device gemm splitk use B2C map. * Device GroupedGemmXdlSplitKCShuffle * Example for GroupedGemm Xdl SplitK * Introduce Device GroupedGemmSplitK * Fix updating kbatch size. * Add instance mk-nk-mn * Enable set kbatch in profiler. * Add GGemmSplitK mk-kn-mn instances * Add more instances & split into multiple files. * minor fix * tuning * clean * disabled failed instances * use pipe v2 * Ignore arg on not supported arch. * fix warning --------- Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Jing Zhang <jizhan@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
This commit is contained in:
@@ -587,4 +587,52 @@ struct OffsettedBlockToCTileMap
|
||||
index_t block_start_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Simple tile mapping which creates 3D grid of block of threads.
|
||||
*
|
||||
* @paragraph Description
|
||||
* This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
|
||||
* blocks. The first 2D are regular 2D tiles created by division of output GEMM
|
||||
* dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension,
|
||||
* which denotes the number of blocks we use to divide work on GEMM K dimension onto.
|
||||
*
|
||||
* @tparam MPerBlock Output block tile size in M dimension.
|
||||
* @tparam NPerBlock Output block tile size in N dimension.
|
||||
*/
|
||||
template <index_t MPerBlock, index_t NPerBlock>
|
||||
struct BlockToCTileMap_3DGrid_KSplit
|
||||
{
|
||||
|
||||
__host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
|
||||
|
||||
__host__ __device__ constexpr auto
|
||||
CalculateGridSize(index_t M, index_t N, index_t k_split) const
|
||||
{
|
||||
// Create 3D grid
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return std::make_tuple(N0, M0, k_split);
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
|
||||
{
|
||||
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -15,16 +15,20 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
|
||||
const Block2CTileMap& b2c_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
@@ -32,9 +36,10 @@ __global__ void
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
|
||||
karg, static_cast<void*>(p_shared));
|
||||
karg, static_cast<void*>(p_shared), b2c_map);
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = b2c_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -478,8 +483,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
|
||||
// return block_id to C matrix tile idx (m0, n0, k_split) mapping
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
|
||||
{
|
||||
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
|
||||
}
|
||||
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
|
||||
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap>
|
||||
__device__ static void Run(const Argument& karg,
|
||||
void* __restrict__ p_shared_block,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const FloatAB* p_a_grid = karg.p_a_grid;
|
||||
const FloatAB* p_b_grid = karg.p_b_grid;
|
||||
@@ -504,11 +522,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
// divide block work by [KBatch, M, N]
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
|
||||
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
@@ -651,6 +679,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
#if 1
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
@@ -662,6 +691,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
#else
|
||||
auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
#endif
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -680,6 +723,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
#if 0
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
@@ -725,6 +769,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
|
||||
(K0PerBlock * K1));
|
||||
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
|
||||
a_b_k0_m_k1_block_desc,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
b_b_k0_n_k1_block_desc,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
#endif
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user