mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Make sure that GEMM sizes in K dimension are supported. (#527)
* apply new K-dimension check in gemm_xdl_cshuffle * add K-dim check to gemm_xdl and batched_gemm_xdl * fix syntax * fix syntax * clean-up the debug output
This commit is contained in:
@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
kraw_{K}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t kraw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.kraw_ % K1 != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
|
||||
@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
|
||||
<< " ) " << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
kraw_{K}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t kraw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(arg.kraw_ % K1 != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
|
||||
@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
kraw_{KRaw}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t kraw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) &&
|
||||
!(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
|
||||
Reference in New Issue
Block a user