Finalize the conv specialization for the 3x3 filter, pad 1, stride 1, dilation 1 case.

This commit is contained in:
Ville Pietilä
2026-02-03 04:10:09 -05:00
parent a814ba15fd
commit d132df2bf5
2 changed files with 53 additions and 7 deletions

View File

@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
InElementOp,
WeiElementOp,
OutElementOp,
//ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
//ConvSpec, // ConvForwardSpecialization
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200,
GemmSpec, // GemmSpecialization
256, // BlockSize
@@ -108,7 +108,7 @@ using DeviceConvFwdInstance =
S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, must match with num merged groups
1, // Vector load/store size for output tensor = CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::BlockGemmPipelineVersion::v1,
InKernelDataType,
WeiKernelDataType,
false, // No direct load