Finalize conv specialization for filter 3x3, pad 1, stride 1, dilation 1 case.

This commit is contained in:
Ville Pietilä
2026-02-03 04:10:09 -05:00
parent a814ba15fd
commit d132df2bf5
2 changed files with 53 additions and 7 deletions

View File

@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
InElementOp,
WeiElementOp,
OutElementOp,
//ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
//ConvSpec, // ConvForwardSpecialization
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200,
GemmSpec, // GemmSpecialization
256, // BlockSize
@@ -108,7 +108,7 @@ using DeviceConvFwdInstance =
S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, must match with num merged groups
1, // Vector load/store size for output tensor = CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::BlockGemmPipelineVersion::v1,
InKernelDataType,
WeiKernelDataType,
false, // No direct load

View File

@@ -882,12 +882,12 @@ struct TransformConvFwdToGemm
throw std::runtime_error(oss.str());
}
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
constexpr auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N, H, W, NumGroupsToMerge, C),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
NStride, HiStride, WiStride, GStride, CStride));
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
constexpr auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(H, Pad1, Pad1),
@@ -899,7 +899,7 @@ struct TransformConvFwdToGemm
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
constexpr auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, H),
@@ -916,12 +916,14 @@ struct TransformConvFwdToGemm
Sequence<5>{},
Sequence<6>{}));
return transform_tensor_descriptor(
constexpr auto gemmm_gemmn_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N, H, W, NumGroupsToMerge)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return gemmm_gemmn_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -1477,6 +1479,50 @@ struct TransformConvFwdToGemm
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr (ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200 &&
NumGroupsToMerge > 1)
{
constexpr ck::index_t C = 4;
constexpr ck::index_t K = 4;
using FilterSizeNumType =
ck::conditional_t<NDimSpatial == 1,
Number<3*C>,
ck::conditional_t<NDimSpatial == 2, Number<9*C>, Number<27*C>>>;
constexpr ck::index_t KStrideTensorB = FilterSizeNumType{};
constexpr ck::index_t GStrideTensorB = Number<FilterSizeNumType{} * K>{};
constexpr ck::index_t CStrideTensorB = 1;
// Ensure the strides match with the expected values
if (KStrideTensorB != this->KStrideTensorB_ ||
GStrideTensorB != this->GStrideTensorB_ ||
CStrideTensorB != this->CStrideTensorB_)
{
std::stringstream oss;
oss << "Error: Stride mismatch in MakeBDescriptor_N_K for special case. "
<< "Expected KStrideTensorB: " << KStrideTensorB
<< ", Actual: " << this->KStrideTensorB_ << ". "
<< "Expected GStrideTensorB: " << GStrideTensorB
<< ", Actual: " << this->GStrideTensorB_ << ". "
<< "Expected CStrideTensorB: " << CStrideTensorB
<< ", Actual: " << this->CStrideTensorB_ << ".";
throw std::runtime_error(oss.str());
}
constexpr auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStrideTensorB, GStrideTensorB, CStrideTensorB));
constexpr auto gemmn_gemm_k_desc = transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
make_pass_through_transform(FilterSizeNumType{})),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return gemmn_gemm_k_desc;
}
else
{
if constexpr(NumGroupsToMerge == 1)