mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[GEMM] Gemm universal device operation (#1154)
* Optimize GEMM on MI200/300: 1. Add new blockwise gemm pipeline 2. Add irregular splitk instances * clang format + typo fix * Fix a bug * initial commit * Add more instances to irregular splitk * blkgemm pipeline v1~4 prototype * Sanity Checked. Known issue: 1. Poor performance of splitk 2. Register spill on blkgemmpipeline v3 * Sanity and Performance fix: 1. fix a bug related to sanity in grouped b2c mapping 2. fix a bug related to sanity and performance in splitk offset * Sanity and API update: 1. Remove prefetch stage 2. Fix valid check bug 3. Add first gemm_universal instance into ckProfiler * Add NN instances for gemm universal * 1. Add NT instances for gemm_universal 2. Fix a bug about Kpadding in gemm_universal * Fix a bug regarding padding Odd K number * remove kernel print * Fix KPadding bug... * Update safety check * another try to fix kpadding.. * Sanity checked * new instances.. * clang format+typo fix * remove clang format script's change * Add non-hotloop compile option * 1. Add fp16xfp8 example 2. pull packed convert f8 from pr1150 * Some miscs.. opt and fix * Add pipeline description docs * Split universal gemm instance library to cut profiler compiling time * uncomment cmakefile * Fix a bug caused by blockwise_gemm_pipe_v2 * reduce default splitk to 1 * Add 224x256x64 tile size * update, including: 1. Experiment pipeline 5~7 2. Optimization for pipeline 4 3. Organized instance library * temp save * temp save * Permuted lds layout, sanity and function checked * clang format * Move OOB check from RunRead to RunWrite, for better software pipeline. 
TODO: agpr spill when NN layout * clangformat * A/B splitpipe scheduler for v3 * Fix two bugs * bug fix * fix a bug in oob check * Example for mixed fp16_fp8 gemm * Clean experimental code blocks * Add mixed precision gemm into profiler * tempsave * optimize m/n major lds layout * Add RRR GEMM mixed precision instances * Optimize f8 matrix transpose * Add test_gemm_universal * A/B split schedule for blkpip v5 * Take ds_read2 into iglp scheduling scheme * format * fixed cmake * Add llvm-option into CI cmake flag --------- Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -259,46 +259,20 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
|
||||
BlockToCTileMap_M00_N0_M01Adapt;
|
||||
};
|
||||
|
||||
// Rows of column-vectors
|
||||
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
|
||||
// Grouped Rows of column-vectors WGP mapping
|
||||
// Optimized for MI300-like multi-die chips
|
||||
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
|
||||
const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
|
||||
operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
|
||||
operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
|
||||
index_t N,
|
||||
index_t M01 = 8)
|
||||
: M_(M), N_(N), M01_(M01)
|
||||
{
|
||||
#if 0
|
||||
if(get_thread_global_1d_id()==0){
|
||||
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
|
||||
: BlockToCTileMap_Grouped_M00_N0_M01Adapt(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
@@ -309,12 +283,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
return M0 * N0;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
@@ -329,67 +297,82 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
|
||||
|
||||
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
if(M0 == 1)
|
||||
{
|
||||
return make_tuple(0, block_1d_id);
|
||||
}
|
||||
else if(N0 == 1)
|
||||
{
|
||||
return make_tuple(block_1d_id, 0);
|
||||
}
|
||||
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
else
|
||||
{
|
||||
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
|
||||
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
|
||||
auto group_id_x = block_1d_id % GroupNum;
|
||||
auto group_id_y = block_1d_id / GroupNum;
|
||||
auto remap_block_1d_id =
|
||||
group_id_x <= big_group_num
|
||||
? group_id_x * group_size + group_id_y
|
||||
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
|
||||
|
||||
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
|
||||
auto group_id = block_1d_id % GroupNum;
|
||||
auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
|
||||
index_t idx_N0 = remap_block_1d_id % N0;
|
||||
index_t idx_M0 = remap_block_1d_id / N0;
|
||||
|
||||
index_t idx_N0 = remap_block_1d_id % N0;
|
||||
index_t idx_M0 = remap_block_1d_id / N0;
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
/**
|
||||
* idxN0
|
||||
*
|
||||
* |< mtx N >|
|
||||
*
|
||||
* NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
* N_0 N_1 N_2 N_3
|
||||
* - |-----------|-----------|-----------|-----|-----|-
|
||||
* ^ | - - 0 |/----> 2 | | | |
|
||||
* | | | / | | | | | M_0 MPerBlock
|
||||
* | M | /| | | | | |
|
||||
* |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
* | 1 | / | | | blockid | | |
|
||||
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
* | - V 1 | - 3 | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* mtx M | | | | | |
|
||||
* | | | | | | M_2 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* | | | | | |
|
||||
* | | | | | | M_3 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* V | | | | | |
|
||||
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* Example:
|
||||
* assume:
|
||||
* M0 = 5
|
||||
* N0 = 4
|
||||
* block_1d_id = 5
|
||||
* M01 = 2
|
||||
*
|
||||
* idx_N0 = 1
|
||||
* idx_M0 = 1
|
||||
* M01_adapt = 2
|
||||
* idx_M00 = 0
|
||||
* idx_M01 = 1
|
||||
* idx_N0_M01_local = 5
|
||||
* output {1, 2}
|
||||
*/
|
||||
|
||||
/**
|
||||
* idxN0
|
||||
*
|
||||
* |< mtx N >|
|
||||
*
|
||||
* NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
* N_0 N_1 N_2 N_3
|
||||
* - |-----------|-----------|-----------|-----|-----|-
|
||||
* ^ | - - 0 |/----> 2 | | | |
|
||||
* | | | / | | | | | M_0 MPerBlock
|
||||
* | M | /| | | | | |
|
||||
* |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
* | 1 | / | | | blockid | | |
|
||||
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
* | - V 1 | - 3 | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* mtx M | | | | | |
|
||||
* | | | | | | M_2 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* | | | | | |
|
||||
* | | | | | | M_3 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* V | | | | | |
|
||||
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* Example:
|
||||
* assume:
|
||||
* M0 = 5
|
||||
* N0 = 4
|
||||
* block_1d_id = 5
|
||||
* M01 = 2
|
||||
*
|
||||
* idx_N0 = 1
|
||||
* idx_M0 = 1
|
||||
* M01_adapt = 2
|
||||
* idx_M00 = 0
|
||||
* idx_M01 = 1
|
||||
* idx_N0_M01_local = 5
|
||||
* output {1, 2}
|
||||
*/
|
||||
|
||||
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
@@ -405,15 +388,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
index_t M01_;
|
||||
};
|
||||
|
||||
// keep the redundant type argument for backward compatibility
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
|
||||
{
|
||||
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
|
||||
};
|
||||
|
||||
// columns of row-vectors
|
||||
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
|
||||
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user