Fix transform and instances for grouped conv bwd data (#848)

* Fix transform and instances for grouped conv bwd data

* Add instances for small K and small C

* Remove workaround after fix

* Fix interface tests
This commit is contained in:
Bartłomiej Kocot
2023-08-22 18:25:41 +02:00
committed by GitHub
parent eac50708d9
commit 595d23be14
8 changed files with 203 additions and 160 deletions

View File

@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{