mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Implement padding and sanity checks for fused GEMM+GEMM (#376)
* GemmPadder and GemmGemmPadder
* proper padding using GemmGemmPadder
* test gemm_gemm padding
* properly check size K in IsSupportedArgument()
* properly check size requirement given SrcScalarPerVector in IsSupportedArgument()
* comment
* format
This commit is contained in:
@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const std::vector<index_t>& lengths_m_n_k_o)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
// K is rounded to the nearest multiple of K1 during tensor transformation, so use the raw length KRaw instead
|
||||
const auto KRaw = lengths_m_n_k_o[2];
|
||||
if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
|
||||
Gemm1N % Gemm1NPerBlock == 0))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user