[CK_BUILDER] Instance traits for conv bwd weight algorithms (#3498)

Added instance traits for the following bwd weight conv algorithms

DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
DeviceGroupedConvBwdWeight_Wmma_CShuffle
DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle
DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffleV3
DeviceGroupedConvBwdWeight_DL
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
Added also unit tests for instance traits of those bwd weigth algorithms that are currently exposed by the narrow CK build for MIOpen.
---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-12-31 15:41:15 -08:00
committed by GitHub
parent f3e4d46faa
commit 6e8c401e33
25 changed files with 3206 additions and 2 deletions

View File

@@ -20,6 +20,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_dl.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1227,6 +1232,24 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_dl.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -28,6 +28,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1250,6 +1255,25 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3::Argument structure!");
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -26,6 +26,11 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1207,6 +1212,24 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!");
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -30,6 +30,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1571,6 +1576,25 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3::Argument structure!");
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -30,6 +30,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -2098,6 +2103,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -19,6 +19,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -865,6 +870,24 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -31,6 +31,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1422,6 +1427,24 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3::Argument structure!");
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device

View File

@@ -29,6 +29,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -1548,6 +1553,24 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};
} // namespace device