mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
adding thread group
This commit is contained in:
@@ -27,7 +27,8 @@ template <typename ALayout,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t ABBlockTransferThreadGroupSize,
|
||||
index_t BlockGemmThreadGroupSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
@@ -346,7 +347,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
ABBlockTransferThreadGroupSize,
|
||||
BlockGemmThreadGroupSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
@@ -487,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
@@ -502,22 +504,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -539,7 +541,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
@@ -554,22 +556,22 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -673,7 +675,8 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
// clang-format off
|
||||
str << "DeviceGemm_Xdl_CShuffle_v2"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< ABBlockTransferThreadGroupSize << ", "
|
||||
<< BlockGemmThreadGroupSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename AGridDesc,
|
||||
template <typename ABBlockTransferThreadGroup,
|
||||
typename BlockGemmThreadGroup,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
@@ -23,7 +25,9 @@ template <typename AGridDesc,
|
||||
struct GridwiseGemmPipeline_v2;
|
||||
|
||||
// 1-stage prefetch
|
||||
template <typename AGridDesc,
|
||||
template <typename ABBlockTransferThreadGroup,
|
||||
typename BlockGemmThreadGroup,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
@@ -38,7 +42,9 @@ template <typename AGridDesc,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer,
|
||||
bool HasMainLoop>
|
||||
struct GridwiseGemmPipeline_v2<AGridDesc,
|
||||
struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
BlockGemmThreadGroup,
|
||||
AGridDesc,
|
||||
ABlockDesc,
|
||||
ABlockTransfer,
|
||||
AGridBuffer,
|
||||
@@ -58,19 +64,24 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static __device__ void RunProducer(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
index_t num_loop)
|
||||
// Default constructor; the body currently contains no statements.
__device__ constexpr GridwiseGemmPipeline_v2()
|
||||
{
|
||||
// TODO static assert
// NOTE(review): the invariants to assert are not visible here - presumably
// constraints on the two thread-group template arguments; confirm.
|
||||
}
|
||||
|
||||
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
index_t num_loop)
|
||||
{
|
||||
// global read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void RunConsumer(ABlockBuffer& a_block_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
@@ -193,6 +204,45 @@ struct GridwiseGemmPipeline_v2<AGridDesc,
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point of the pipeline: routes the calling thread to one of the two
// sub-pipelines depending on which thread group it belongs to.
//
//   * If ABBlockTransferThreadGroup::IsBelong() is true, the thread runs
//     RunABBlockTransferPipeline() with the A/B descriptors, blockwise copies,
//     grid/block buffers and copy steps passed in here.
//   * Otherwise, if BlockGemmThreadGroup::IsBelong() is true, the thread runs
//     RunBlockGemmPipeline() on the block buffers with blockwise_gemm and
//     c_thread_buf.
//   * A thread that belongs to neither group does nothing.
//
// Because the second test is an `else if`, a thread for which both IsBelong()
// predicates are true only runs the A/B transfer pipeline.
//
// num_loop is forwarded unchanged to whichever sub-pipeline is selected.
static __device__ void Run(const AGridDesc& a_grid_desc,
                           const ABlockDesc& a_block_desc,
                           ABlockTransfer& a_blockwise_copy,
                           const AGridBuffer& a_grid_buf,
                           ABlockBuffer& a_block_buf,
                           const ABlockTransferStep& a_block_copy_step,
                           const BGridDesc& b_grid_desc,
                           const BBlockDesc& b_block_desc,
                           BBlockTransfer& b_blockwise_copy,
                           const BGridBuffer& b_grid_buf,
                           BBlockBuffer& b_block_buf,
                           const BBlockTransferStep& b_block_copy_step,
                           const BlockwiseGemm& blockwise_gemm,
                           CThreadBuffer& c_thread_buf,
                           index_t num_loop)
{
    if(ABBlockTransferThreadGroup::IsBelong())
    {
        // Both sub-pipelines are static members of this struct, so they are
        // called directly, and the arguments are this function's own
        // parameters (not the caller's local variable names).
        RunABBlockTransferPipeline(a_grid_desc,
                                   a_block_desc,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_copy_step,
                                   b_grid_desc,
                                   b_block_desc,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_copy_step,
                                   num_loop);
    }
    else if(BlockGemmThreadGroup::IsBelong())
    {
        RunBlockGemmPipeline(a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
    }
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -67,7 +67,8 @@ template <typename FloatAB,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t ABBlockTransferThreadGroupSize,
|
||||
index_t BlockGemmThreadGroupSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
@@ -114,6 +115,50 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
static constexpr auto AK1 = Number<AK1Value>{};
|
||||
static constexpr auto BK1 = Number<BK1Value>{};
|
||||
|
||||
using ThisThreadBlock =
|
||||
AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
|
||||
|
||||
#if 1
|
||||
using ABBlockTransferThreadGroup = ThisThreadBlock;
|
||||
using BlockGemmThreadGroup = ThisThreadBlock;
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#else
|
||||
struct ABBlockTransferThreadGroup
|
||||
{
|
||||
__device__ static constexpr index_t GetNumOfThread()
|
||||
{
|
||||
return ABBlockTransferThreadGroupSize;
|
||||
}
|
||||
|
||||
__device__ static constexpr bool IsBelong()
|
||||
{
|
||||
return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize;
|
||||
}
|
||||
|
||||
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
|
||||
};
|
||||
|
||||
struct BlockGemmThreadGroup
|
||||
{
|
||||
__device__ static constexpr index_t GetNumOfThread()
|
||||
{
|
||||
return ABBlockTransferThreadGroupSize;
|
||||
}
|
||||
|
||||
__device__ static constexpr bool IsBelong()
|
||||
{
|
||||
return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize;
|
||||
}
|
||||
|
||||
__device__ static index_t GetThreadId()
|
||||
{
|
||||
return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize;
|
||||
}
|
||||
};
|
||||
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
using ThisThreadBlock = AnyThreadBlock<BlockSize>;
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
@@ -380,7 +423,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ABBlockTransferThreadGroup,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
@@ -420,7 +463,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockGemmThreadGroup,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
@@ -447,6 +490,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
#if 1
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
|
||||
@@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumGemmKPrefetchStage,
|
||||
HasMainK0BlockLoop>{};
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
BlockGemmThreadGroup,
|
||||
remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
|
||||
remove_cvref_t<decltype(a_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(a_grid_buf)>,
|
||||
remove_cvref_t<decltype(a_block_buf)>,
|
||||
remove_cvref_t<decltype(a_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
|
||||
remove_cvref_t<decltype(b_blockwise_copy)>,
|
||||
remove_cvref_t<decltype(b_grid_buf)>,
|
||||
remove_cvref_t<decltype(b_block_buf)>,
|
||||
remove_cvref_t<decltype(b_block_slice_copy_step)>,
|
||||
remove_cvref_t<decltype(blockwise_gemm)>,
|
||||
remove_cvref_t<decltype(c_thread_buf)>,
|
||||
NumGemmKPrefetchStage,
|
||||
HasMainK0BlockLoop>{};
|
||||
#endif
|
||||
|
||||
gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
@@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // index_t BlockSize,
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
@@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
if(BlockGemmThreadGroup::IsBelong())
|
||||
{
|
||||
// thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
}
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
if(CShuffleBlockTransferThreadGroup::IsBelong())
|
||||
{
|
||||
// block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user