Make sure that GEMM sizes in the K dimension are supported. (#527)

* apply new K-dimension check in gemm_xdl_cshuffle

* add K-dim check to gemm_xdl and batched_gemm_xdl

* fix syntax

* fix syntax

* clean up the debug output
This commit is contained in:
Illia Silin
2022-12-08 09:48:43 -08:00
committed by GitHub
parent 614a7b1bb0
commit d58b7f5155
4 changed files with 32 additions and 3 deletions

View File

@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
kraw_{K}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_;
};
// Invoker
@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,

View File

@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],

View File

@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
kraw_{K}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_;
};
// Invoker
@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false;
}
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,

View File

@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
kraw_{KRaw}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_;
};
// Invoker
@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return false;
}
if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,