Multiple fixes to GroupedGemm+SplitK (#707)

* Add license header.

* Reduce number of logged output. Add constant initialization.

* Add functional tests for grouped_gemm with different kbatch value.

* Add debug log informations + remove unused code.

* Don't pass kbatch to CalculateKPadded.

* Turn on logging in grouped gemm and gemm splitk profiler

* Debug: limit number of test cases to run;

* Log more information and initialize with constant value.

* Turn on DEBUG_LOG

* Add more debug log informations.

* Limit the number of instances to compile.

* Use GridwiseGemmPipeline

* Use KBatch to calculate K0

* Multiple DebugLog messages.

* Unit tests for multiple KBatch values.

* Refactoring

* Disable logging
* extract out of if statement KBatch update.

* Uncomment instances.

* Disable DebugLog.

* Use Kbatch when calculate KPadded.

* Fix CGridDesc padding.

* Use available helper functions.

* Uncomment code commented for debuggin.

* Remove unnecessary debug log messages.

* Uncomment previously commented code for debug purposes.

* Add KBatch info to profiler output summary log.

* Add gtests for gemm splitk using ckProfiler API.

* Add more test-cases for different data layout.

* Add more test cases for gemm splitk

* Remove old test.

* Unit tests for MKNK ggemm interface.

* Fix and add more unit-tests.

* Constepxr everything!

* Increase error threshold for fp16 and splitk.

Since we're using fp16 atomic add for splitk there's a
known precision loss.

---------

Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
Adam Osewski
2023-05-30 14:09:06 +02:00
committed by GitHub
parent c2d7a29dec
commit 70e4eb567f
20 changed files with 1263 additions and 471 deletions

View File

@@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
static constexpr PipelineVersion PipelineVer = PipelineVersion::v2;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
@@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
NumGemmKPrefetchStage,
MPerBlock,
NPerBlock,
K0PerBlock,
@@ -112,7 +118,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
@@ -257,7 +265,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch};
}
@@ -290,7 +298,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
KBatch);
}

View File

@@ -85,7 +85,7 @@ template <typename ALayout,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -152,6 +152,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
NumGemmKPrefetchStage,
MPerBlock,
NPerBlock,
K0PerBlock,
@@ -179,7 +180,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVersion::v2>;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
@@ -265,8 +268,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
@@ -319,8 +321,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
@@ -501,6 +503,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
return false;
}
@@ -509,14 +516,15 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
#if DEBUG_LOG
if(not group_arg_valid)
{
std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
#if DEBUG_LOG
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
a.Print();
}
#endif // DEBUG_LOG
supported &= group_arg_valid;
}
supported = supported && group_arg_valid;
}
return supported;
}