mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[GEMM] Optimization for MI200/300. (#1135)

* Optimize GEMM on MI200/300:
  1. Add a new blockwise GEMM pipeline
  2. Add irregular split-K instances
* clang-format + typo fix
* Fix a bug
@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
        BlockToCTileMap_M00_N0_M01Adapt;
};

// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;

template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
        const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__
    BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
                                                                index_t N,
                                                                index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__
    BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
        : BlockToCTileMap_Grouped_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        // distribute the M0 * N0 tiles round-robin over GroupNum groups, then
        // lay the groups out back to back, so consecutive block ids land in
        // different groups
        const auto group_size  = math::integer_divide_ceil(M0 * N0, GroupNum);
        auto group_id          = block_1d_id % GroupNum;
        auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;

        index_t idx_N0 = remap_block_1d_id % N0;
        index_t idx_M0 = remap_block_1d_id / N0;

        // shrink M01 for the tail rows so the last (M0 % M01) rows still form
        // a complete column-major tile group
        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        /**
         *  idxN0
         *
         *      |<             mtx   N                    >|
         *
         *        NPerBlock   NPerBlock   NPerBlock   NPerBlock
         *          N_0         N_1         N_2         N_3
         *   -  |-----------|-----------|-----------|-----|-----|-
         *   ^  | -  -   0  | /-----> 2 |           |     |     |
         *   |  | |      |  |/        | |           |     |     |  M_0  MPerBlock
         *   |  M |      | /|         | |           |     |     |
         *   |--0-|------|/-|---------|-|-----------|-----|-----|-
         *   |  1 |      /  |         | |  blockid  |     |     |
         * idxM0  |     /|  |         V |     5     |     |     |  M_1  MPerBlock
         *   |  | -    V 1  |  -      3 |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         * mtx M            |           |           |     |     |
         *   |  |           |           |           |     |     |  M_2  MPerBlock
         *   |  |           |           |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         *   |  |           |           |           |     |     |
         *   |  |           |           |           |     |     |  M_3  MPerBlock
         *   |  |           |           |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         *   V  |           |           |           |     |     |
         *   -  |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
         *      |           |           |           |     |     |
         *      |-----------|-----------|-----------|-----|-----|-
         *
         *  Example:
         *   assume:
         *      M0 = 5
         *      N0 = 4
         *      block_1d_id = 5
         *      M01 = 2
         *
         *   idx_N0 = 1
         *   idx_M0 = 1
         *   M01_adapt = 2
         *   idx_M00 = 0
         *   idx_M01 = 1
         *   idx_N0_M01_local = 5
         *   output {1, 2}
         */

        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};

// keep the redundant type argument for backward compatibility
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
    : BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
    using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
        BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};
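A quick way to sanity-check the index arithmetic above is to replay it on the host with plain ints. The sketch below is not part of the diff: grouped_m00_n0_m01_map is a hypothetical standalone helper that mirrors CalculateBottomIndex, run here with GroupNum = 1 so the group remap is the identity and the result can be compared against the worked example in the comment (M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2 -> {1, 2}).

#include <cassert>
#include <cstdio>
#include <utility>

// Hypothetical host-only mirror of
// BlockToCTileMap_Grouped_M00_N0_M01Adapt::CalculateBottomIndex.
std::pair<int, int>
grouped_m00_n0_m01_map(int block_1d_id, int M0, int N0, int M01, int GroupNum)
{
    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

    // round-robin the tiles over GroupNum groups, then concatenate the groups
    const int group_size = (M0 * N0 + GroupNum - 1) / GroupNum;
    const int group_id   = block_1d_id % GroupNum;
    const int remap_id   = group_id * group_size + block_1d_id / GroupNum;

    const int idx_N0 = remap_id % N0;
    const int idx_M0 = remap_id / N0;

    // shrink M01 on the tail rows of the tile grid
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

    const int idx_M00          = idx_M0 / M01;
    const int idx_M01          = idx_M0 % M01;
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01,
            idx_N0_M01_local / M01_adapt};
}

int main()
{
    // values from the worked example in the comment block above
    const auto [m, n] = grouped_m00_n0_m01_map(5, 5, 4, 2, /*GroupNum=*/1);
    assert(m == 1 && n == 2); // matches "output {1, 2}"
    std::printf("tile index = {%d, %d}\n", m, n);
}

With GroupNum > 1 the same sketch also shows what the grouping buys: remap_id for consecutive block ids differs by group_size, so blocks launched back to back are spread across distant regions of the C tile grid.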
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;

template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        index_t idx_M0 = block_1d_id % M0;
        index_t idx_N0 = block_1d_id / M0;

        // shrink N01 for the tail columns so the last (N0 % N01) columns
        // still form a complete row-major tile group
        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;

        index_t idx_N00          = idx_N0 / N01_;
        index_t idx_N01          = idx_N0 % N01_;
        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;

        /**
         *  idxN0
         *
         *      |<             mtx   N                    >|
         *
         *      |<---N01--->|
         *   -  |-----------|-----------|-----------|-----|-----|-
         *   ^  |     0 ----------->  1 |           |     |     |
         *   |  |          /            |           |     |     |  M_0  MPerBlock
         *   |  |         /             |           |     |     |
         *   |--|--------/--------------|-----------|-----|-----|-
         *   |  |       |               |           |     |     |
         * idxM0|       V               |           |     |     |  M_1  MPerBlock
         *   |  |     2 ----------->  3 |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         * mtx M|           |  blockid  |           |     |     |
         *   |  |           |     5     |           |     |     |  M_2  MPerBlock
         *   |  |           |           |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         *   |  |           |           |           |     |     |
         *   |  |           |           |           |     |     |  M_3  MPerBlock
         *   |  |           |           |           |     |     |
         *   |  |-----------|-----------|-----------|-----|-----|-
         *   V  |           |           |           |     |     |
         *   -  |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
         *      |           |           |           |     |     |
         *      |-----------|-----------|-----------|-----|-----|-
         *        NPerBlock   NPerBlock   NPerBlock   NPerBlock
         *          N_0         N_1         N_2         N_3
         *
         *  Example:
         *   assume:
         *      N0 = 5
         *      M0 = 4
         *      block_1d_id = 5
         *      N01 = 2
         *
         *   idx_M0 = 1
         *   idx_N0 = 1
         *   N01_adapt = 2
         *   idx_N00 = 0
         *   idx_N01 = 1
         *   idx_M0_N01_local = 5
         *   output {2, 1}
         */

        return make_tuple(idx_M0_N01_local / N01_adapt,
                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};
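The same host-side check works for this variant; only the roles of M and N swap. Below is a minimal sketch under the same caveat as before (n00_m0_n01_map is a hypothetical standalone helper, not part of the diff), checked against the worked example N0 = 5, M0 = 4, block_1d_id = 5, N01 = 2 -> {2, 1}.

#include <cassert>
#include <utility>

// Hypothetical host-only mirror of
// BlockToCTileMap_N00_M0_N01Adapt::CalculateBottomIndex.
std::pair<int, int> n00_m0_n01_map(int block_1d_id, int M0, int N0, int N01)
{
    block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

    const int idx_M0 = block_1d_id % M0;
    const int idx_N0 = block_1d_id / M0;

    // shrink N01 on the tail columns of the tile grid
    const int N01_adapt = (idx_N0 < N0 - N0 % N01) ? N01 : N0 % N01;

    const int idx_N00          = idx_N0 / N01;
    const int idx_N01          = idx_N0 % N01;
    const int idx_M0_N01_local = idx_M0 + idx_N01 * M0;

    return {idx_M0_N01_local / N01_adapt,
            idx_M0_N01_local % N01_adapt + idx_N00 * N01};
}

int main()
{
    // values from the worked example in the comment block above
    const auto [m, n] = n00_m0_n01_map(5, /*M0=*/4, /*N0=*/5, /*N01=*/2);
    assert(m == 2 && n == 1); // matches "output {2, 1}"
}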
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
File diff suppressed because it is too large
@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
        {
            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return transform_tensor_descriptor(
                a_grid_desc_m_kpad,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else
        {
            return transform_tensor_descriptor(
@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
        {
            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return transform_tensor_descriptor(
                b_grid_desc_kpad_n,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else
        {
            return transform_tensor_descriptor(
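Both KPadding branches follow the same pattern: first right-pad K to KPad, then unmerge the padded axis into (KBatch, K0Padded, K1), so every split-K batch sees an identically shaped K0Padded x K1 slice even when KBatch does not divide K evenly (the "irregular split-K" case from the commit message). The sketch below only illustrates that arithmetic; the ceil-division form of K0Padded is an assumption for illustration, and the kernel's actual formula may differ (e.g. by also rounding to a multiple of K0PerBlock).

#include <cstdio>

int main()
{
    // irregular split-K: KBatch = 3 does not evenly divide K / K1 = 125
    const int K = 1000, K1 = 8, KBatch = 3;

    // assumed form: smallest K0Padded with KBatch * K0Padded * K1 >= K
    const int K0Padded = (K + K1 * KBatch - 1) / (K1 * KBatch);
    const int KPad     = KBatch * K0Padded * K1;

    // here make_right_pad_transform(K, KPad - K) would extend K by 8 elements,
    // after which the unmerge into (3, 42, 8) leaves no remainder
    std::printf("K0Padded = %d, KPad = %d, pad = %d\n", K0Padded, KPad, KPad - K);
}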