[CK_BUILDER] Add DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 to CK Builder (#5284)

Add factory, InstanceTraits, and conv traits support for the WMMA V3
forward convolution kernel, enabling the CK Builder to generate and
dispatch this kernel variant used by MIOpen on gfx11/gfx12 GPUs.

## Motivation

As reported in issue #4944, MIOpen includes WMMA V3 forward convolution
kernels, so this PR adds support for those kernels similarly to other
supported kernels.

## Technical Details

This follows the same implementation as the other kernels. I added some
support for reflection, but I left a few todos since we need to
generalize our convolution traits to generalize across WMMA/MFMA and
CK/CKTile.

## Test Plan

Added faster tests to `ninja smoke-builder` that check the
instance-traits logic, and I added longer tests that instantiate
kernels, following the existing pattern in other kernals.

## Test Result

I tested all code with `ninja check-builder` on a gfx1101 build and ran
on gfx1101.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
John Shumway
2026-03-10 16:41:51 -07:00
committed by GitHub
parent 270c651d3c
commit dc1ea3fb7a
15 changed files with 916 additions and 0 deletions

View File

@@ -146,6 +146,7 @@ set(INSTANCE_STRING_TESTS
if (CK_USE_WMMA)
list(APPEND INSTANCE_STRING_TESTS
test_instance_string_fwd_grp_conv_wmma_v3.cpp
test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp
test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp
test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp
@@ -172,6 +173,13 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
)
if (CK_USE_WMMA)
target_sources(test_ckb_build_fwd_instances PRIVATE
conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp
)
endif()
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
set(BWD_WEIGHT_TESTS

View File

@@ -0,0 +1,105 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_64x64x32)
.with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_4x2_per_wave)
.with_transfer(cku::Transfer_4x16x1)
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(Fwd2DFp16_WmmaV3_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3",
expected_transfer_parameters,
"Default",
"Intrawave",
"v1",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
TEST(Fwd2DFp16_WmmaV3_GNHWC, Execution)
{
if(!ck_tile::get_device_name().starts_with("gfx11") &&
!ck_tile::get_device_name().starts_with("gfx12"))
{
// Note: WMMA kernel requires gfx11 or gfx12
GTEST_SKIP() << "unsupported architecture";
}
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 16,
.groups = 1,
.input_channels = 32,
.output_channels = 48,
.image =
{
.width = 56,
.height = 64,
},
.filter =
{
.width = 3,
.height = 5,
},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -632,6 +632,14 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
BlockGemm_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemmABK1_,
Transfer_<>,
ConvSpecializationFwd_,
BlockGemm_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,

View File

@@ -12,6 +12,7 @@
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
@@ -626,6 +627,118 @@ TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, WmmaV3InstanceStringReturnsCorrectFormat)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpec
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
ck::Sequence<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
ck::Sequence<1, 16, 1, 4>, // CDEBlockTransferClusterLengths
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1>; // BlkGemmPipelineVer
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all template parameters
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",64" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // AK1
",8" // BK1
",16" // MPerWmma
",16" // NPerWmma
",4" // MRepeat
",2" // NRepeat
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",true" // ABlockLdsExtraM
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",true" // BBlockLdsExtraN
",1" // CShuffleMRepeatPerShuffle
",1" // CShuffleNRepeatPerShuffle
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
",1" // CDEBlockTransferScalarPerVector_NPerBlock
",Intrawave" // BlkGemmPipeSched
",v1" // BlkGemmPipelineVer
",true" // UseThreadTileTransfer
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",1>"; // NumGroupsToMerge
// Verify the generated string matches exactly
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
{
using DeviceInstance =

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/conv_describe.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp>
namespace {
namespace ckr = ck_tile::reflect;
// Use the template helper to get a working instance configuration
using InstanceTuple = ck::tensor_operation::device::instance::
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Expected complete instance string based on the first instance from
// device_grouped_conv_fwd_wmma_cshufflev3_f16_instances
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",64" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // AK1
",8" // BK1
",16" // MPerWmma
",16" // NPerWmma
",4" // MRepeat
",2" // NRepeat
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",true" // ABlockLdsExtraM
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",true" // BBlockLdsExtraN
",1" // CShuffleMRepeatPerShuffle
",1" // CShuffleNRepeatPerShuffle
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
",1" // CDEBlockTransferScalarPerVector_NPerBlock
",Intrawave" // BlkGemmPipeSched
",v1" // BlkGemmPipelineVer
",true" // UseThreadTileTransfer
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",1>"; // NumGroupsToMerge
// Test describe() through base class pointer for WMMA V3 variant
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmmaV3)
{
using BaseClass = ck::tensor_operation::device::BaseOperator;
DeviceInstance device_instance;
BaseClass* base_ptr = &device_instance;
auto desc = base_ptr->describe();
ASSERT_NE(desc, nullptr);
EXPECT_EQ(desc->instance_string(), expected_str);
}
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvWmmaV3)
{
EXPECT_EQ(ckr::describe<DeviceInstance>().instance_string(), expected_str);
}
} // namespace

View File

@@ -344,6 +344,16 @@ constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1};
constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_4x2_per_wave{.ak1 = 8,
.bk1 = 8,
.m_per_wmma = 16,
.n_per_wmma = 16,
.m_wmma_per_wave = 4,
.n_wmma_per_wave = 2};
constexpr ThreadBlock ThreadBlock_64_64x64x32{.block_size = 64,
.tile_size = {.m = 64, .n = 64, .k = 32}};
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};

View File

@@ -409,6 +409,17 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
<< to_string(static_cast<WmmaGemmABK1_>(t)) << ","
<< to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>(
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t)