[CK_BUILDER] Forward convolution builder improvements (#3179)

Proposed changes
Improves the forward convolution builder implementation and addresses leftover feedback from PR #3138. Main changes:

Refactored tests so that they better reflect the builder pattern. The templates and types for the convolution algorithm concepts are created via a factory that facilitates programmatic creation of the device op instances.
Moved tests into anonymous namespace.
The convolution factory had a lot of if-else constructs where CK Builder types were converted into CK library types. I initially had trouble using static_assert in the default branch of a switch, as the static_assert was evaluated at compile time even for valid types. However, if we change the static_assert to throw "<error message>", it results in a compile-time error only if the default branch is actually hit. This assumes that the function is consteval. Hence, changed all conversions in the convolution factory to use switch, which is more intuitive.
Removed the explicit device op definition from the convolution signature and the corresponding predicate file. The device ops are defined by the corresponding concepts. This allowed the removal of a lot of boilerplate code from the convolution factory.
Added inheritance and convolution algorithm specialization to handle device ops that are specializations of more generic ones. The large tensor support is more naturally expressed by this pattern.
Added support for the FP8 data type.

* WIP: Builder for expected test results.

* Improve ckb fwd conv instance tests.

* clang-format

* Change if-else statements into switch in conv factory.

* Fix clang-formatting.

* Removed unnecessary includes.

* Added missing copyright.

* Remove explicit device op flag from convolution signature.

* Add missing concept.

* Fix build.

* clang-format

* Add test for building conv fwd FP8 instances.

* Add missing header to instance traits.

* Clean-up recently added instances.

* Introduce inheritance and specialization.

* Use builder to build conv algorithm templates and types.

* clang-format

* Fix conv description tests.

---------

Co-authored-by: John Shumway <john.shumwayjr@gmail.com>
This commit is contained in:
Ville Pietilä
2025-11-13 18:47:25 +02:00
committed by GitHub
parent ca2ee0eb8a
commit 7d57bc169f
26 changed files with 946 additions and 1439 deletions

View File

@@ -117,103 +117,6 @@ struct BlockTransferABC
AccessOrder src_access_order_b;
};
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
ThreadBlock thread_block;
GridwiseXdlGemm gridwise_gemm;
BlockTransferABC block_transfer;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
BlockGemm block_gemm;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesGemmSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
ThreadBlock thread_block;
GridwiseXdlGemm gridwise_gemm;
BlockTransferABC block_transfer;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
{
ThreadBlock thread_block;
GridwiseWmmaGemm gridwise_gemm;
BlockTransferABC block_transfer;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
// DL-specific descriptors
struct DlThreadConfig
{
@@ -227,12 +130,12 @@ static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
struct DlThreadCluster
{
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
std::array<size_t, 2> m1_xs;
std::array<size_t, 2> n1_xs;
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
struct DlBlockTransferK0M0M1K1
struct DlBlockTransfer
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
@@ -242,56 +145,212 @@ struct DlBlockTransferK0M0M1K1
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
struct DlBlockTransferK0N0N1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
struct DlCThreadTransfer
struct DlEpilogue
{
std::array<size_t, 6> src_dst_access_order;
size_t src_dst_vector_dim;
size_t dst_scalar_per_vector;
};
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
static_assert(ckb::DlEpilogueDescriptor<DlEpilogue>);
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// Factory types
struct ThreadBlock_
{
ThreadBlock thread_block;
};
struct XdlGemm_
{
GridwiseXdlGemm gridwise_gemm;
};
struct WmmaGemm_
{
GridwiseWmmaGemm gridwise_gemm;
};
struct BlockTransfer_
{
BlockTransferABC block_transfer;
};
struct ConvSpecialization_
{
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
DlThreadConfig dl_thread_config;
DlThreadCluster dl_thread_cluster;
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
DlCThreadTransfer dl_c_thread_transfer;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
struct Prefetch_
{
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
PipelineScheduler loop_scheduler;
};
struct BlockGemm_
{
BlockGemm block_gemm;
};
struct DlThreadConfig_
{
DlThreadConfig thread_config;
};
struct DlThreadCluster_
{
DlThreadCluster thread_cluster;
};
struct DlBlockTransfer_
{
DlBlockTransfer block_transfer_a;
DlBlockTransfer block_transfer_b;
};
struct DlEpilogue_
{
DlEpilogue epilogue_c;
};
// Specialization wrapper for large tensor support
template <typename BaseAlgorithm>
struct LargeTensorWrapper
{
BaseAlgorithm base_algorithm;
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Factory
template <typename... Components>
struct ConvAlgorithmTemplate : Components...
{
template <typename TB>
constexpr auto with_thread_block(const TB& tb) const
{
static_assert(std::is_base_of_v<ThreadBlock_, ConvAlgorithmTemplate>);
auto result = *this;
result.thread_block = tb;
return result;
}
template <typename GemmConfig>
constexpr auto with_gemm_config(const GemmConfig& gemm) const
{
auto result = *this;
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
else if constexpr(std::is_base_of_v<WmmaGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
return result;
}
template <typename BT>
constexpr auto with_block_transfer(const BT& bt) const
{
static_assert(std::is_base_of_v<BlockTransfer_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_transfer = bt;
return result;
}
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
GemmSpecialization gemm_spec) const
{
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
auto result = *this;
result.fwd_specialization = fwd_spec;
result.gemm_specialization = gemm_spec;
return result;
}
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
size_t groups_to_merge,
PipelineScheduler scheduler) const
{
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
auto result = *this;
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
result.num_groups_to_merge = groups_to_merge;
result.loop_scheduler = scheduler;
return result;
}
template <typename BG>
constexpr auto with_block_gemm(const BG& bg) const
{
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_gemm = bg;
return result;
}
template <typename TC>
constexpr auto with_dl_thread_config(const TC& tc) const
{
static_assert(std::is_base_of_v<DlThreadConfig_, ConvAlgorithmTemplate>);
auto result = *this;
result.thread_config = tc;
return result;
}
template <typename TCl>
constexpr auto with_dl_thread_cluster(const TCl& tcl) const
{
static_assert(std::is_base_of_v<DlThreadCluster_, ConvAlgorithmTemplate>);
auto result = *this;
result.thread_cluster = tcl;
return result;
}
template <typename BTA, typename BTB>
constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const
{
static_assert(std::is_base_of_v<DlBlockTransfer_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_transfer_a = bta;
result.block_transfer_b = btb;
return result;
}
constexpr auto with_dl_epilogue(const DlEpilogue& epi) const
{
static_assert(std::is_base_of_v<DlEpilogue_, ConvAlgorithmTemplate>);
auto result = *this;
result.epilogue_c = epi;
return result;
}
};
// Algorithm types
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvAlgorithmTemplate<ThreadBlock_,
ConvSpecialization_,
DlThreadConfig_,
DlThreadCluster_,
DlBlockTransfer_,
DlEpilogue_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
} // namespace ck_tile::builder::test

View File

@@ -17,7 +17,6 @@ struct ConvSignature
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
GroupConvDeviceOp device_operation;
};
static_assert(ConvSignatureDescriptor<ConvSignature>);