Add instance traits for two more grouped forward convolutions (#3112)

This commit is contained in:
John Shumway
2025-10-29 08:04:13 -07:00
committed by GitHub
parent 121bf0e1f3
commit cafaeb6b7b
11 changed files with 1193 additions and 13 deletions

View File

@@ -28,6 +28,9 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
@@ -2063,6 +2066,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);

View File

@@ -24,6 +24,9 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
@@ -1220,6 +1223,20 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
};
} // namespace device