Add memory index guard in wmma device ops (#667)

This commit is contained in:
Haocong WANG
2023-04-12 04:42:47 +08:00
committed by GitHub
parent f532988713
commit e85178b4ca
2 changed files with 16 additions and 0 deletions

View File

@@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}

View File

@@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
return true;
}