mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Multiple fixes to GroupedGemm+SplitK (#707)
* Add license header. * Reduce number of logged output. Add constant initialization. * Add functional tests for grouped_gemm with different kbatch value. * Add debug log informations + remove unused code. * Don't pass kbatch to CalculateKPadded. * Turn on logging in grouped gemm and gemm splitk profiler * Debug: limit number of test cases to run; * Log more information and initialize with constant value. * Turn on DEBUG_LOG * Add more debug log informations. * Limit the number of instances to compile. * Use GridwiseGemmPipeline * Use KBatch to calculate K0 * Multiple DebugLog messages. * Unit tests for multiple KBatch values. * Refactoring * Disable logging * extract out of if statement KBatch update. * Uncomment instances. * Disable DebugLog. * Use Kbatch when calculate KPadded. * Fix CGridDesc padding. * Use available helper functions. * Uncomment code commented for debuggin. * Remove unnecessary debug log messages. * Uncomment previously commented code for debug purposes. * Add KBatch info to profiler output summary log. * Add gtests for gemm splitk using ckProfiler API. * Add more test-cases for different data layout. * Add more test cases for gemm splitk * Remove old test. * Unit tests for MKNK ggemm interface. * Fix and add more unit-tests. * Constepxr everything! * Increase error threshold for fp16 and splitk. Since we're using fp16 atomic add for splitk there's a known precision loss. --------- Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// TODO: should be exposed as Tparams.
|
||||
static constexpr index_t NumGemmKPrefetchStage = 1;
|
||||
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
|
||||
static constexpr PipelineVersion PipelineVer = PipelineVersion::v2;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
@@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -112,7 +118,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
|
||||
@@ -257,7 +265,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch};
|
||||
}
|
||||
@@ -290,7 +298,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateKPadded(K, KBatch),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
KBatch);
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -152,6 +152,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -179,7 +180,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVersion::v2>;
|
||||
|
||||
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
|
||||
using Block2ETileMapKSplit =
|
||||
@@ -265,8 +268,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
@@ -319,8 +321,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
|
||||
karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
@@ -501,6 +503,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
#endif // DEBUG_LOG
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -509,14 +516,15 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
#if DEBUG_LOG
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
|
||||
#if DEBUG_LOG
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
}
|
||||
#endif // DEBUG_LOG
|
||||
supported &= group_arg_valid;
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user