mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Codegen: isSupportedArgument check (#1417)
* added isSupportedArgument check into codegen device op * adding function call * remove commented code
This commit is contained in:
@@ -233,6 +233,12 @@ extern "C" __global__ void run_${name}(
|
||||
${BElementwiseOperation}{},
|
||||
${CDEElementwiseOperation}{1.0f, 1.0f});
|
||||
|
||||
if(!DeviceConv::IsSupportedArgument(arg))
|
||||
{
|
||||
printf("Arguement is not supported.\n");
|
||||
return;
|
||||
};
|
||||
|
||||
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
|
||||
|
||||
// GridwiseGemm
|
||||
|
||||
@@ -727,6 +727,181 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ck::Array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
static __device__ __host__ bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check ConvolutionForwardSpecialization
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t ConvStride = arg.conv_filter_strides_[i];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of A
|
||||
// FIXME: layout
|
||||
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
|
||||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC>)
|
||||
{
|
||||
const index_t C = arg.a_g_n_c_wis_lengths_[2];
|
||||
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
|
||||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
|
||||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
|
||||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
|
||||
is_same_v<BLayout, ctc::KZYXGC>)
|
||||
|
||||
{
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of Ds
|
||||
bool valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
// FIXME: layout
|
||||
if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
|
||||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
|
||||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
|
||||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
|
||||
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
// G and K must be the same
|
||||
if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
|
||||
arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// E and D must have the same shape
|
||||
for(index_t d = 0; d < NDimSpatial + 3; d++)
|
||||
{
|
||||
if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
|
||||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
|
||||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
|
||||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
|
||||
is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
const index_t K = arg.e_g_n_k_wos_lengths_[2];
|
||||
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
// Genarate tuples with the same descriptors
|
||||
const auto as_grid_desc_ak0_m_ak1 =
|
||||
generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 =
|
||||
generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
|
||||
return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __host__ auto MakeArgument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
|
||||
Reference in New Issue
Block a user