mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Ck profiler splitk (#857)
* updated regular gemm * update ckProfiler * fixed gtests --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -58,7 +58,9 @@ template <typename ADataType,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
typename ComputeType = CDataType>
|
||||
typename ComputeType = CDataType,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
|
||||
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -77,7 +79,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
|
||||
static constexpr PipelineVersion PipelineVer = PipelineVersion::v1;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
|
||||
@@ -114,7 +114,8 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
// Current implementation does not support multiple D fusions.
|
||||
enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>>,
|
||||
@@ -183,7 +184,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVersion::v2>;
|
||||
PipelineVer>;
|
||||
|
||||
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
|
||||
using Block2ETileMapKSplit =
|
||||
|
||||
@@ -789,53 +789,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
#if 0
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k0_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
|
||||
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (karg.K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
|
||||
@@ -858,7 +811,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
#endif
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user