mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fix DL GEMM instances with too large vector size (#901)
* Fix vector lengths of DL GEMM instances with padding * Add checks for correctness of vector lengths in DL GEMM
This commit is contained in:
committed by
GitHub
parent
f17af2e9ed
commit
63cd459248
@@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
M_raw_{M},
|
||||
N_raw_{N},
|
||||
K_raw_{K},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
@@ -314,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
|
||||
index_t M_raw_;
|
||||
index_t N_raw_;
|
||||
index_t K_raw_;
|
||||
|
||||
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
@@ -485,6 +492,50 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Make sure that the M, N, K dimensions before padding are divisible by respective vector
|
||||
// lengths.
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr auto A_K_vec_length =
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
|
||||
if(arg.K_raw_ % A_K_vec_length != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_M_vec_lenght =
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
|
||||
if(arg.M_raw_ % A_M_vec_lenght != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
constexpr auto B_N_vec_lenght =
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
|
||||
if(arg.N_raw_ % B_N_vec_lenght != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_K_vec_length =
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
|
||||
if(arg.K_raw_ % B_K_vec_length != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
|
||||
Reference in New Issue
Block a user