Concept bug fixes.

This commit is contained in:
Ville Pietilä
2025-12-22 09:23:47 -05:00
parent 5ee99d83d5
commit dacf82d652
5 changed files with 20 additions and 19 deletions

View File

@@ -157,36 +157,36 @@ concept SpecifiesTileThreadBlock = requires {
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept GridwiseFwdXdlGemmDescriptor = requires {
{ T::ak1 } -> std::convertible_to<size_t>;
{ T::bk1 } -> std::convertible_to<size_t>;
{ T::xdl_params } -> GridwiseXdlGemmDescriptor;
concept GridwiseFwdXdlGemmDescriptor = requires (T t){
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept GridwiseBwdXdlGemmDescriptor = requires {
{ T::k0_per_block } -> std::convertible_to<size_t>;
{ T::k1 } -> std::convertible_to<size_t>;
{ T::xdl_params } -> GridwiseXdlGemmDescriptor;
concept GridwiseBwdXdlGemmDescriptor = requires (T t){
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseFwdXdlGemm = requires {
{ T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
concept SpecifiesGridwiseFwdXdlGemm = requires (T t) {
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseBwdXdlGemm = requires {
{ T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
concept SpecifiesGridwiseBwdXdlGemm = requires (T t) {
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires {
{ T::gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
concept SpecifiesGridwiseWmmaGemm = requires (T t){
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.

View File

@@ -161,7 +161,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization()
{
constexpr auto specialization = ALGORITHM.bwd_specialization;
constexpr auto specialization = ALGORITHM.bwd_weight_specialization;
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
switch(specialization)
{

View File

@@ -25,7 +25,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::Transfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
static_assert(cku::SpecifiesGridwiseBwdXdlGemm<decltype(ALGORITHM)>, "Error");
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

View File

@@ -211,7 +211,7 @@ struct ConvSpecializationFwd_
struct ConvSpecializationBwdWeight_
{
ConvSpecialization bwd_specialization;
ConvSpecialization bwd_weight_specialization;
};
struct Prefetch_
@@ -400,7 +400,7 @@ struct ConvAlgorithmTemplate : Components...
{
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
auto result = *this;
result.bwd_specialization = bwd_spec;
result.bwd_weight_specialization = bwd_spec;
return result;
}

View File

@@ -278,7 +278,7 @@ template <>
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
{
std::ostringstream oss;
oss << to_string(t.bwd_specialization);
oss << to_string(t.bwd_weight_specialization);
return oss.str();
}