mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add 2GB limitation for grouped conv bwd weight (#3054)
This commit is contained in:
@@ -1886,6 +1886,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -1417,6 +1417,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1359,6 +1359,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(karg.M * karg.K * sizeof(ADataType) <= TwoGB &&
|
||||
karg.N * karg.K * sizeof(BDataType) <= TwoGB &&
|
||||
karg.M * karg.N * sizeof(CDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -581,6 +581,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_b_k0_m_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
|
||||
b_b_k0_n_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB &&
|
||||
c_m_n_grid_desc.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user