mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Fix transform and instances for grouped conv bwd data (#848)
* Fix transform and instances for grouped conv bwd data * Add instances for small K and small C * Remove workaround after fix * Fix interface tests
This commit is contained in:
@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto AK = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto BK = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
// check consistency of desc
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// check tile size
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K / KPerBlock;
|
||||
const auto num_k_loop = AK / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user