mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Forward convolution builder improvements (#3179)
Proposed changes Improve the forward convolution builder implementation and addressed leftover feedback left from PR #3138. Main changes Refactored tests such that they reflect better the builder pattern. The templates and types for the convolution algorithm concepts are created via factory that facilitates programmatic creation of the device op instances. Moved tests into anonymous namespace. The convolution factory had lot of if-else constructs when CK Builder types were converted into CK library types. I had initially trouble in using static_assert in the default branch of switch as the static_assert was evaluated at compile time even for valid types. However, if we change the static_assert to throw "<error message>", it will result in a compile-time error only if the default branch is actually hit. This assumes that the function is consteval. Hence, changed all conversions in the convolution factory to use switch, which is more intuitive. Removed the explicit device op definition from convolution signature and the corresponding predicate file. The device ops are defined by the corresponding concepts. This allowed to remove lot of boilerplate code from the convolution factory. Adde inheritance and convolution algorithm specialization to handle device ops that are specialization of a more generic ones. The large tensor support is more naturally expressed by this pattern. Added support for the FP8 data type. * WIP: Builder for expected test results. * Improve ckb fwd conv instance tests. * clang-format * Change if-else statements into switch in conv factory. * Fix clang-formatting. * Removed unnecessary includes. * Added missing copyright. * Remove explicit device op flag from from convolution signature. * Add missing concept. * Fix build. * clang-format * Add test for building conv fwd FP8 instances. * Add missing header to instance traits. * Clean-up recently added instances. * Introduce inheritance and specialization. * Use builder to build conv algorithm templates and types. * clang-format * Fix conv description tests. --------- Co-authored-by: John Shumway <john.shumwayjr@gmail.com>
This commit is contained in:
@@ -95,7 +95,8 @@ concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlgorithm concept.
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
// concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
@@ -183,6 +184,12 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesLargeTensorSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
@@ -204,9 +211,9 @@ concept DlThreadClusterDescriptor = requires(T t) {
|
||||
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_M0_M1_K1 format
|
||||
// Concept for DL block transfer
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
concept DlBlockTransferDescriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
@@ -216,21 +223,9 @@ concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_N0_N1_K1 format
|
||||
// Concept for DL epilogue
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL C thread transfer
|
||||
template <typename T>
|
||||
concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
concept DlEpilogueDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
@@ -239,31 +234,63 @@ concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadConfig = requires {
|
||||
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
|
||||
{ T::thread_config } -> DlThreadConfigDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread cluster
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadCluster = requires {
|
||||
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
|
||||
{ T::thread_cluster } -> DlThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL A block transfer
|
||||
// Concept to check if algorithm specifies DL block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferA = requires {
|
||||
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL B block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferB = requires {
|
||||
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
|
||||
concept SpecifiesDlBlockTransfer = requires {
|
||||
{ T::block_transfer_a } -> DlBlockTransferDescriptor;
|
||||
{ T::block_transfer_b } -> DlBlockTransferDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlCThreadTransfer = requires {
|
||||
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
|
||||
concept SpecifiesDlEpilogue = requires {
|
||||
{ T::epilogue_c } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Concepts for the different device ops */
|
||||
/******************************************** */
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(T::base_algorithm)> &&
|
||||
SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -256,6 +256,19 @@ struct ConvTensorTypes<DataType::I8>
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck::f8_t;
|
||||
using AComputeType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::f8_t;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
@@ -302,49 +315,31 @@ struct BlockGemmSpec
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockGemmSpec SetBlockGemm()
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
|
||||
switch(BG.scheduler)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
}
|
||||
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break;
|
||||
case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break;
|
||||
case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
|
||||
if constexpr(BG.pipeline_version == PipelineVersion::V1)
|
||||
switch(BG.pipeline_version)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break;
|
||||
case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break;
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
default: throw "Unknown PipelineVersion";
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -453,18 +448,13 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
|
||||
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
return ck::LoopScheduler::Default;
|
||||
}
|
||||
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
return ck::LoopScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
case PipelineScheduler::DEFAULT: return ck_loop_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -472,29 +462,16 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
if constexpr(pipeline_version == PipelineVersion::V1)
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
return ck::PipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::PipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
static_assert(false, "V3 is used only for stream-K.");
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::PipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
|
||||
{
|
||||
return ck::PipelineVersion::weight_only;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,74 +479,27 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
|
||||
{
|
||||
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == GemmSpecialization::Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::OPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::OPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GemmSpecialization");
|
||||
case GemmSpecialization::Default: return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding;
|
||||
default: throw "Unknown GemmSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,30 +507,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == PipelineVersion::V1)
|
||||
using ck_pipeline = ck::BlockGemmPipelineVersion;
|
||||
switch(version)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -608,26 +523,14 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -639,7 +542,12 @@ namespace ck_tile::builder {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
struct ConvFactory
|
||||
{
|
||||
// This will trigger if a specialization for the given convolution direction is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(false, "Unsupported device operation.");
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
@@ -647,7 +555,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -658,26 +566,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesBlockGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block gemm pipeline.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
|
||||
ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
@@ -769,7 +657,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -780,31 +668,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
static_assert(SpecifiesNumGroupsToMerge<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of groups to merge.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -891,7 +754,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -902,27 +765,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseWmmaGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -1008,7 +850,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -1019,24 +861,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
|
||||
"DL algorithm must specify thread config.");
|
||||
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
|
||||
"DL algorithm must specify thread cluster.");
|
||||
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
|
||||
"DL algorithm must specify A block transfer.");
|
||||
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
|
||||
"DL algorithm must specify B block transfer.");
|
||||
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
|
||||
"DL algorithm must specify C thread transfer.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -1045,7 +869,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
@@ -1053,12 +877,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
@@ -1074,7 +898,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
@@ -1090,7 +914,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
@@ -1148,7 +972,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -1159,45 +983,24 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
factory_internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
factory_internal::SetGemmSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<BASE_ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvABlockTransfer<BASE_ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBBlockTransfer<BASE_ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
@@ -1227,7 +1030,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_predicates.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -41,9 +40,6 @@ template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
template <typename T>
|
||||
concept ConvDeviceOp = std::same_as<std::remove_cvref_t<T>, GroupConvDeviceOp>;
|
||||
|
||||
template <typename T>
|
||||
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
|
||||
|
||||
@@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.layout } -> ConvLayout;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
{ t.device_operation } -> ConvDeviceOp;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
@@ -63,7 +58,18 @@ template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
requires IsValidConvDeviceOp<Sig>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**********************************************
|
||||
* Conv Direction Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
/**********************************************
|
||||
* Conv Fwd Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
// Generic predicate to check if signature uses any forward convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsForward =
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Weight Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
// Generic predicate to check if signature uses any backward weight convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardWeight =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Data Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Generic predicate to check if signature uses any backward data convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardData =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Generic Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Generic predicate to check if signature uses any device operation.
|
||||
template <auto Sig>
|
||||
concept IsValidConvDeviceOp = ConvDeviceOpIsForward<Sig> || ConvDeviceOpIsBackwardData<Sig> ||
|
||||
ConvDeviceOpIsBackwardWeight<Sig>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Enumeration for CK Device Operation types.
|
||||
// This allows the builder to select which device operation template to instantiate
|
||||
// based on the user's requirements.
|
||||
enum class DeviceOpType
|
||||
{
|
||||
// Forward Convolution - Non-grouped
|
||||
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
|
||||
|
||||
// Forward Convolution - Grouped
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -14,6 +14,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -74,52 +74,6 @@ enum class ConvDirection
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Forward convolution device operations.
|
||||
enum class FwdGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
};
|
||||
|
||||
// Backward data convolution device operations.
|
||||
enum class BwdDataGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdDataMultipleD,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
};
|
||||
|
||||
// Backward weight convolution device operations.
|
||||
enum class BwdWeightGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdWeight,
|
||||
DeviceGroupedConvBwdWeight_Dl,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeightMultipleD,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
|
||||
};
|
||||
|
||||
// Structural type for device operation
|
||||
struct GroupConvDeviceOp
|
||||
{
|
||||
union
|
||||
{
|
||||
FwdGroupConvDeviceOperation _fwd;
|
||||
BwdDataGroupConvDeviceOperation _bwd_data;
|
||||
BwdWeightGroupConvDeviceOperation _bwd_weight;
|
||||
};
|
||||
|
||||
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
@@ -219,6 +173,11 @@ enum class PipelineScheduler
|
||||
INTERWAVE
|
||||
};
|
||||
|
||||
enum class ConvAlgorithmSpecialization
|
||||
{
|
||||
LARGE_TENSOR
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
{
|
||||
@@ -286,61 +245,6 @@ inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum FwdGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
|
||||
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
|
||||
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdDataGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdWeightGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
|
||||
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
|
||||
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
|
||||
@@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
@@ -47,7 +48,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
)
|
||||
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
@@ -66,10 +68,10 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_traits
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_description
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
|
||||
@@ -1,34 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE
|
||||
// elementwise op
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V2,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v2"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D FP16 (channels-last) with DEFAULT specialization
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D I8 (channels-last) with and DEFAULT specialization
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
|
||||
.data_type = DataType::I8,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
|
||||
.data_type = DataType::I8,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x32x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,54 +1,63 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"Filter3x3",
|
||||
"BlkGemmPipelineVersion: v5"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,69 +1,59 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB)
|
||||
.with_dl_epilogue(DlEpilogueC);
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"});
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB)
|
||||
.with_dl_epilogue(DlEpilogueC);
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
35
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
35
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP8,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1_fp8)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -1,53 +1,64 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"256, 256, 128, 32",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"128, 128, 128, 32",
|
||||
"Filter1x1Pad0"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -117,103 +117,6 @@ struct BlockTransferABC
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
BlockGemm block_gemm;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesGemmSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
// DL-specific descriptors
|
||||
struct DlThreadConfig
|
||||
{
|
||||
@@ -227,12 +130,12 @@ static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
|
||||
|
||||
struct DlThreadCluster
|
||||
{
|
||||
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> m1_xs;
|
||||
std::array<size_t, 2> n1_xs;
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
struct DlBlockTransferK0M0M1K1
|
||||
struct DlBlockTransfer
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
@@ -242,56 +145,212 @@ struct DlBlockTransferK0M0M1K1
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
|
||||
|
||||
struct DlBlockTransferK0N0N1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
|
||||
|
||||
struct DlCThreadTransfer
|
||||
struct DlEpilogue
|
||||
{
|
||||
std::array<size_t, 6> src_dst_access_order;
|
||||
size_t src_dst_vector_dim;
|
||||
size_t dst_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
|
||||
static_assert(ckb::DlEpilogueDescriptor<DlEpilogue>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
// Factory types
|
||||
|
||||
struct ThreadBlock_
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct XdlGemm_
|
||||
{
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
{
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BlockTransfer_
|
||||
{
|
||||
BlockTransferABC block_transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
{
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
DlThreadConfig dl_thread_config;
|
||||
DlThreadCluster dl_thread_cluster;
|
||||
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
|
||||
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
|
||||
DlCThreadTransfer dl_c_thread_transfer;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
};
|
||||
|
||||
struct DlThreadConfig_
|
||||
{
|
||||
DlThreadConfig thread_config;
|
||||
};
|
||||
|
||||
struct DlThreadCluster_
|
||||
{
|
||||
DlThreadCluster thread_cluster;
|
||||
};
|
||||
|
||||
struct DlBlockTransfer_
|
||||
{
|
||||
DlBlockTransfer block_transfer_a;
|
||||
DlBlockTransfer block_transfer_b;
|
||||
};
|
||||
|
||||
struct DlEpilogue_
|
||||
{
|
||||
DlEpilogue epilogue_c;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
template <typename BaseAlgorithm>
|
||||
struct LargeTensorWrapper
|
||||
{
|
||||
BaseAlgorithm base_algorithm;
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Factory
|
||||
|
||||
template <typename... Components>
|
||||
struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
|
||||
template <typename TB>
|
||||
constexpr auto with_thread_block(const TB& tb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ThreadBlock_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_block = tb;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
constexpr auto with_gemm_config(const GemmConfig& gemm) const
|
||||
{
|
||||
auto result = *this;
|
||||
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<WmmaGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BT>
|
||||
constexpr auto with_block_transfer(const BT& bt) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer = bt;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.fwd_specialization = fwd_spec;
|
||||
result.gemm_specialization = gemm_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
|
||||
size_t groups_to_merge,
|
||||
PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
|
||||
result.num_groups_to_merge = groups_to_merge;
|
||||
result.loop_scheduler = scheduler;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TC>
|
||||
constexpr auto with_dl_thread_config(const TC& tc) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlThreadConfig_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_config = tc;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TCl>
|
||||
constexpr auto with_dl_thread_cluster(const TCl& tcl) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlThreadCluster_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_cluster = tcl;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BTA, typename BTB>
|
||||
constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlBlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer_a = bta;
|
||||
result.block_transfer_b = btb;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_dl_epilogue(const DlEpilogue& epi) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlEpilogue_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.epilogue_c = epi;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlBlockTransfer_,
|
||||
DlEpilogue_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -17,7 +17,6 @@ struct ConvSignature
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
GroupConvDeviceOp device_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ struct ConvSignature
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
|
||||
@@ -1,383 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test implementation
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
PipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == PipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 2,
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
{
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
|
||||
{
|
||||
// DL thread configuration
|
||||
constexpr DlThreadConfig DlThreadCfg{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
// DL thread cluster
|
||||
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
// DL A block transfer - K0_M0_M1_K1 format
|
||||
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL B block transfer - K0_N0_N1_K1 format
|
||||
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL C thread transfer
|
||||
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.dl_thread_config = DlThreadCfg,
|
||||
.dl_thread_cluster = DlCluster,
|
||||
.dl_block_transfer_a = DlBlockTransferA,
|
||||
.dl_block_transfer_b = DlBlockTransferB,
|
||||
.dl_c_thread_transfer = DlCTransfer};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
// Large_Tensor uses the same descriptor as regular XDL CShuffle
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(
|
||||
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
184
experimental/builder/test/utils/ckb_conv_test_configs.hpp
Normal file
184
experimental/builder/test/utils/ckb_conv_test_configs.hpp
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
31
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
31
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
using namespace test;
|
||||
|
||||
// Common test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
|
||||
{
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
|
||||
for(const auto& component : kernel_instance_components)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
Reference in New Issue
Block a user