mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5284 (commit 76b5b15)
[CK_BUILDER] Add DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 to CK Builder (#5284) Add factory, InstanceTraits, and conv traits support for the WMMA V3 forward convolution kernel, enabling the CK Builder to generate and dispatch this kernel variant used by MIOpen on gfx11/gfx12 GPUs. ## Motivation As reported in issue #4944, MIOpen includes WMMA V3 forward convolution kernels, so this PR adds support for those kernels similarly to other supported kernels. ## Technical Details This follows the same implementation as the other kernels. I added some support for reflection, but I left a few todos since we need to generalize our convolution traits to generalize across WMMA/MFMA and CK/CKTile. ## Test Plan Added faster tests to `ninja smoke-builder` that check the instance-traits logic, and I added longer tests that instantiate kernels, following the existing pattern in other kernals. ## Test Result I tested all code with `ninja check-builder` on a gfx1101 build and ran on gfx1101. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
26d29374e5
commit
9f47b8a63d
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
@@ -626,6 +627,118 @@ TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, WmmaV3InstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpec
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
ck::Sequence<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
ck::Sequence<1, 16, 1, 4>, // CDEBlockTransferClusterLengths
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1>; // BlkGemmPipelineVer
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",16" // MPerWmma
|
||||
",16" // NPerWmma
|
||||
",4" // MRepeat
|
||||
",2" // NRepeat
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMRepeatPerShuffle
|
||||
",1" // CShuffleNRepeatPerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",Intrawave" // BlkGemmPipeSched
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",true" // UseThreadTileTransfer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
|
||||
Reference in New Issue
Block a user