mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Finalize conv specialization for filter 3x3, pad 1, stride 1, dilation 1 case.
This commit is contained in:
@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
//ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
//ConvSpec, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200,
|
||||
GemmSpec, // GemmSpecialization
|
||||
256, // BlockSize
|
||||
@@ -108,7 +108,7 @@ using DeviceConvFwdInstance =
|
||||
S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; must match the number of merged groups
|
||||
1, // Vector load/store size for output tensor = CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Interwave,
|
||||
ck::BlockGemmPipelineVersion::v2,
|
||||
ck::BlockGemmPipelineVersion::v1,
|
||||
InKernelDataType,
|
||||
WeiKernelDataType,
|
||||
false, // No direct load
|
||||
|
||||
@@ -882,12 +882,12 @@ struct TransformConvFwdToGemm
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
constexpr auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, H, W, NumGroupsToMerge, C),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStride, HiStride, WiStride, GStride, CStride));
|
||||
|
||||
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
constexpr auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(H, Pad1, Pad1),
|
||||
@@ -899,7 +899,7 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
|
||||
constexpr auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_groups_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, H),
|
||||
@@ -916,12 +916,14 @@ struct TransformConvFwdToGemm
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
constexpr auto gemmm_gemmn_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_groups_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, H, W, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(Y, X, C))),
|
||||
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return gemmm_gemmn_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
@@ -1477,6 +1479,50 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr (ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200 &&
|
||||
NumGroupsToMerge > 1)
|
||||
{
|
||||
constexpr ck::index_t C = 4;
|
||||
constexpr ck::index_t K = 4;
|
||||
|
||||
using FilterSizeNumType =
|
||||
ck::conditional_t<NDimSpatial == 1,
|
||||
Number<3*C>,
|
||||
ck::conditional_t<NDimSpatial == 2, Number<9*C>, Number<27*C>>>;
|
||||
|
||||
constexpr ck::index_t KStrideTensorB = FilterSizeNumType{};
|
||||
constexpr ck::index_t GStrideTensorB = Number<FilterSizeNumType{} * K>{};
|
||||
constexpr ck::index_t CStrideTensorB = 1;
|
||||
|
||||
// Ensure the runtime strides match the compile-time values this specialization expects
|
||||
if (KStrideTensorB != this->KStrideTensorB_ ||
|
||||
GStrideTensorB != this->GStrideTensorB_ ||
|
||||
CStrideTensorB != this->CStrideTensorB_)
|
||||
{
|
||||
std::stringstream oss;
|
||||
oss << "Error: Stride mismatch in MakeBDescriptor_N_K for special case. "
|
||||
<< "Expected KStrideTensorB: " << KStrideTensorB
|
||||
<< ", Actual: " << this->KStrideTensorB_ << ". "
|
||||
<< "Expected GStrideTensorB: " << GStrideTensorB
|
||||
<< ", Actual: " << this->GStrideTensorB_ << ". "
|
||||
<< "Expected CStrideTensorB: " << CStrideTensorB
|
||||
<< ", Actual: " << this->CStrideTensorB_ << ".";
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
|
||||
constexpr auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}),
|
||||
make_tuple(KStrideTensorB, GStrideTensorB, CStrideTensorB));
|
||||
constexpr auto gemmn_gemm_k_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
|
||||
make_pass_through_transform(FilterSizeNumType{})),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return gemmn_gemm_k_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
|
||||
Reference in New Issue
Block a user