mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
@@ -8,41 +8,6 @@ namespace driver {
|
||||
|
||||
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
ck::DataTypeEnum_t ABDataTypeEnum;
|
||||
ck::DataTypeEnum_t AccDataTypeEnum;
|
||||
ck::DataTypeEnum_t CDataTypeEnum;
|
||||
|
||||
int BlockSize;
|
||||
|
||||
int GN0;
|
||||
int GK1;
|
||||
|
||||
int GM1PerBlockGM11;
|
||||
int GN1PerBlockGN11;
|
||||
int GK0PerBlock;
|
||||
|
||||
int BM1PerThreadBM11;
|
||||
int BN1PerThreadBN11;
|
||||
int BK0PerThread;
|
||||
|
||||
std::array<int, 2> BM10BN10ThreadClusterBM10Xs;
|
||||
std::array<int, 2> BM10BN10ThreadClusterBN10Xs;
|
||||
|
||||
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
int CThreadTransferDstScalarPerVector;
|
||||
|
||||
bool HasMainKBlockLoop;
|
||||
bool HasDoubleTailKBlockLoop;
|
||||
|
||||
auto GetCompileParameterString() const
|
||||
{
|
||||
// clang-format off
|
||||
@@ -128,11 +93,46 @@ struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
" -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
||||
std::to_string(CThreadTransferDstScalarPerVector) +
|
||||
" -DCK_PARAM_HasMainKBlockLoop=" +
|
||||
std::to_string(HasMainKBlockLoop) +
|
||||
std::to_string(static_cast<int>(HasMainKBlockLoop)) +
|
||||
" -DCK_PARAM_HasDoubleTailKBlockLoop=" +
|
||||
std::to_string(HasDoubleTailKBlockLoop);
|
||||
std::to_string(static_cast<int>(HasDoubleTailKBlockLoop));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
ck::DataTypeEnum_t ABDataTypeEnum;
|
||||
ck::DataTypeEnum_t AccDataTypeEnum;
|
||||
ck::DataTypeEnum_t CDataTypeEnum;
|
||||
|
||||
int BlockSize;
|
||||
|
||||
int GN0;
|
||||
int GK1;
|
||||
|
||||
int GM1PerBlockGM11;
|
||||
int GN1PerBlockGN11;
|
||||
int GK0PerBlock;
|
||||
|
||||
int BM1PerThreadBM11;
|
||||
int BN1PerThreadBN11;
|
||||
int BK0PerThread;
|
||||
|
||||
std::array<int, 2> BM10BN10ThreadClusterBM10Xs;
|
||||
std::array<int, 2> BM10BN10ThreadClusterBN10Xs;
|
||||
|
||||
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
int CThreadTransferDstScalarPerVector;
|
||||
|
||||
bool HasMainKBlockLoop;
|
||||
bool HasDoubleTailKBlockLoop;
|
||||
};
|
||||
|
||||
struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
@@ -230,8 +230,6 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const int C = conv_problem_desc.C;
|
||||
const int Y = conv_problem_desc.Y;
|
||||
const int X = conv_problem_desc.X;
|
||||
@@ -248,12 +246,17 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
|
||||
DataTypeEnum_t AccDataTypeEnum;
|
||||
|
||||
switch(ABDataTypeEnum)
|
||||
if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half)
|
||||
{
|
||||
case DataTypeEnum_t::Float:
|
||||
case DataTypeEnum_t::Half: AccDataTypeEnum = DataTypeEnum_t::Float; break;
|
||||
case DataTypeEnum_t::Int8: AccDataTypeEnum = DataTypeEnum_t::Int32; break;
|
||||
default: return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
AccDataTypeEnum = DataTypeEnum_t::Float;
|
||||
}
|
||||
else if(ABDataTypeEnum == DataTypeEnum_t::Int8)
|
||||
{
|
||||
AccDataTypeEnum = DataTypeEnum_t::Int32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
}
|
||||
|
||||
const int BlockSize = tunable.BlockSize;
|
||||
@@ -343,7 +346,7 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw())
|
||||
{
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param;
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{};
|
||||
bool found = false;
|
||||
|
||||
std::tie(compile_param, found) =
|
||||
@@ -369,8 +372,6 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const int N = conv_problem_desc.N;
|
||||
const int K = conv_problem_desc.K;
|
||||
const int C = conv_problem_desc.C;
|
||||
|
||||
Reference in New Issue
Block a user