mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
refactor
This commit is contained in:
@@ -276,13 +276,14 @@ struct DeviceGemmXdl_C_Shuffle
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
|
||||
@@ -113,7 +113,7 @@ template <
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
index_t NumPrefetch = 1>
|
||||
index_t NumGemmKPrefetchStage = 1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -131,6 +131,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
static constexpr auto AK1 = Number<AK1Value>{};
|
||||
static constexpr auto BK1 = Number<BK1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
constexpr auto max_lds_align = AK1;
|
||||
@@ -246,21 +250,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check NumPrefetch
|
||||
if constexpr(NumPrefetch == 1)
|
||||
{
|
||||
// 1-stage prefetch is always supported
|
||||
}
|
||||
else if constexpr(NumPrefetch == 2)
|
||||
{
|
||||
// 2-stage prefetch currently only supports an even number of K0 loops
|
||||
// TODO: add support for an odd number of K0 loops
|
||||
if(!((K / KPerBlock) % 2 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -290,12 +283,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
// TODO: move this function into the GEMM-pipeline class
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1;
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return has_main_k0_block_loop;
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -434,7 +426,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
NumGemmKPrefetchStage>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
@@ -465,7 +457,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumPrefetch>(
|
||||
NumGemmKPrefetchStage>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
@@ -484,7 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
@@ -512,43 +504,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumPrefetch,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user