mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '7d57bc169f8206f06bc516a7f930f388def32347' into develop
This commit is contained in:
@@ -95,7 +95,8 @@ concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlgorithm concept.
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
// concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
@@ -183,6 +184,12 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesLargeTensorSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
@@ -204,9 +211,9 @@ concept DlThreadClusterDescriptor = requires(T t) {
|
||||
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_M0_M1_K1 format
|
||||
// Concept for DL block transfer
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
concept DlBlockTransferDescriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
@@ -216,21 +223,9 @@ concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_N0_N1_K1 format
|
||||
// Concept for DL epilogue
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL C thread transfer
|
||||
template <typename T>
|
||||
concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
concept DlEpilogueDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
@@ -239,31 +234,63 @@ concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadConfig = requires {
|
||||
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
|
||||
{ T::thread_config } -> DlThreadConfigDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread cluster
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadCluster = requires {
|
||||
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
|
||||
{ T::thread_cluster } -> DlThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL A block transfer
|
||||
// Concept to check if algorithm specifies DL block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferA = requires {
|
||||
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL B block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferB = requires {
|
||||
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
|
||||
concept SpecifiesDlBlockTransfer = requires {
|
||||
{ T::block_transfer_a } -> DlBlockTransferDescriptor;
|
||||
{ T::block_transfer_b } -> DlBlockTransferDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlCThreadTransfer = requires {
|
||||
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
|
||||
concept SpecifiesDlEpilogue = requires {
|
||||
{ T::epilogue_c } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Concepts for the different device ops */
|
||||
/******************************************** */
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConcSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
template <typename T>
|
||||
concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(T::base_algorithm)> &&
|
||||
SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -256,6 +256,19 @@ struct ConvTensorTypes<DataType::I8>
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck::f8_t;
|
||||
using AComputeType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using BComputeType = ck::f8_t;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::f8_t;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
@@ -302,49 +315,31 @@ struct BlockGemmSpec
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockGemmSpec SetBlockGemm()
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
|
||||
switch(BG.scheduler)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
}
|
||||
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break;
|
||||
case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break;
|
||||
case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
|
||||
if constexpr(BG.pipeline_version == PipelineVersion::V1)
|
||||
switch(BG.pipeline_version)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break;
|
||||
case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break;
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
default: throw "Unknown PipelineVersion";
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -453,18 +448,13 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
|
||||
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
return ck::LoopScheduler::Default;
|
||||
}
|
||||
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
return ck::LoopScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
case PipelineScheduler::DEFAULT: return ck_loop_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE.";
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -472,29 +462,16 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
if constexpr(pipeline_version == PipelineVersion::V1)
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
return ck::PipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::PipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
static_assert(false, "V3 is used only for stream-K.");
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::PipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
|
||||
{
|
||||
return ck::PipelineVersion::weight_only;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,74 +479,27 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
|
||||
{
|
||||
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == GemmSpecialization::Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::OPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::OPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GemmSpecialization");
|
||||
case GemmSpecialization::Default: return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding;
|
||||
default: throw "Unknown GemmSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,30 +507,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == PipelineVersion::V1)
|
||||
using ck_pipeline = ck::BlockGemmPipelineVersion;
|
||||
switch(version)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == PipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
case PipelineVersion::V1: return ck_pipeline::v1;
|
||||
case PipelineVersion::V2: return ck_pipeline::v2;
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -608,26 +523,14 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -639,7 +542,12 @@ namespace ck_tile::builder {
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
struct ConvFactory
|
||||
{
|
||||
// This will trigger if a specialization for the given convolution direction is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(false, "Unsupported device operation.");
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
@@ -647,7 +555,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -658,26 +566,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesBlockGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block gemm pipeline.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
|
||||
ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
@@ -769,7 +657,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -780,31 +668,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
static_assert(SpecifiesNumGroupsToMerge<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of groups to merge.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -891,7 +754,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -902,27 +765,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseWmmaGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -1008,7 +850,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -1019,24 +861,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
|
||||
"DL algorithm must specify thread config.");
|
||||
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
|
||||
"DL algorithm must specify thread cluster.");
|
||||
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
|
||||
"DL algorithm must specify A block transfer.");
|
||||
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
|
||||
"DL algorithm must specify B block transfer.");
|
||||
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
|
||||
"DL algorithm must specify C thread transfer.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -1045,7 +869,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
@@ -1053,12 +877,12 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
@@ -1074,7 +898,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
@@ -1090,7 +914,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
@@ -1148,7 +972,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<decltype(ALGORITHM)>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -1159,45 +983,24 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
factory_internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
factory_internal::SetGemmSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<BASE_ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvABlockTransfer<BASE_ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBBlockTransfer<BASE_ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
@@ -1227,7 +1030,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_predicates.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -41,9 +40,6 @@ template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
template <typename T>
|
||||
concept ConvDeviceOp = std::same_as<std::remove_cvref_t<T>, GroupConvDeviceOp>;
|
||||
|
||||
template <typename T>
|
||||
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
|
||||
|
||||
@@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.layout } -> ConvLayout;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
{ t.device_operation } -> ConvDeviceOp;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
@@ -63,7 +58,18 @@ template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
requires IsValidConvDeviceOp<Sig>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**********************************************
|
||||
* Conv Direction Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
/**********************************************
|
||||
* Conv Fwd Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
// Generic predicate to check if signature uses any forward convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsForward =
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Weight Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
// Generic predicate to check if signature uses any backward weight convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardWeight =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Data Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Generic predicate to check if signature uses any backward data convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardData =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Generic Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Generic predicate to check if signature uses any device operation.
|
||||
template <auto Sig>
|
||||
concept IsValidConvDeviceOp = ConvDeviceOpIsForward<Sig> || ConvDeviceOpIsBackwardData<Sig> ||
|
||||
ConvDeviceOpIsBackwardWeight<Sig>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Enumeration for CK Device Operation types.
|
||||
// This allows the builder to select which device operation template to instantiate
|
||||
// based on the user's requirements.
|
||||
enum class DeviceOpType
|
||||
{
|
||||
// Forward Convolution - Non-grouped
|
||||
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
|
||||
|
||||
// Forward Convolution - Grouped
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -14,6 +14,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -74,52 +74,6 @@ enum class ConvDirection
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Forward convolution device operations.
|
||||
enum class FwdGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
};
|
||||
|
||||
// Backward data convolution device operations.
|
||||
enum class BwdDataGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdDataMultipleD,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
};
|
||||
|
||||
// Backward weight convolution device operations.
|
||||
enum class BwdWeightGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdWeight,
|
||||
DeviceGroupedConvBwdWeight_Dl,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeightMultipleD,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
|
||||
};
|
||||
|
||||
// Structural type for device operation
|
||||
struct GroupConvDeviceOp
|
||||
{
|
||||
union
|
||||
{
|
||||
FwdGroupConvDeviceOperation _fwd;
|
||||
BwdDataGroupConvDeviceOperation _bwd_data;
|
||||
BwdWeightGroupConvDeviceOperation _bwd_weight;
|
||||
};
|
||||
|
||||
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
@@ -219,6 +173,11 @@ enum class PipelineScheduler
|
||||
INTERWAVE
|
||||
};
|
||||
|
||||
enum class ConvAlgorithmSpecialization
|
||||
{
|
||||
LARGE_TENSOR
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
{
|
||||
@@ -286,61 +245,6 @@ inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum FwdGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
|
||||
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
|
||||
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdDataGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdWeightGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
|
||||
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
|
||||
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
|
||||
@@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
@@ -47,7 +48,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
)
|
||||
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
@@ -66,10 +68,10 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_traits
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_description
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
|
||||
@@ -1,34 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE
|
||||
// elementwise op
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V2,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v2"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D FP16 (channels-last) with DEFAULT specialization
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D I8 (channels-last) with and DEFAULT specialization
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
|
||||
.data_type = DataType::I8,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
|
||||
.data_type = DataType::I8,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x32x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,54 +1,63 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"Filter3x3",
|
||||
"BlkGemmPipelineVersion: v5"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,69 +1,59 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB)
|
||||
.with_dl_epilogue(DlEpilogueC);
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"});
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB)
|
||||
.with_dl_epilogue(DlEpilogueC);
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
35
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
35
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP8,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1_fp8)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>(
|
||||
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -1,53 +1,64 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"256, 256, 128, 32",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"128, 128, 128, 32",
|
||||
"Filter1x1Pad0"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation =
|
||||
ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_block_transfer(FwdBlockTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
} // namespace
|
||||
|
||||
@@ -117,103 +117,6 @@ struct BlockTransferABC
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
BlockGemm block_gemm;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesGemmSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
// DL-specific descriptors
|
||||
struct DlThreadConfig
|
||||
{
|
||||
@@ -227,12 +130,12 @@ static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
|
||||
|
||||
struct DlThreadCluster
|
||||
{
|
||||
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> m1_xs;
|
||||
std::array<size_t, 2> n1_xs;
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
struct DlBlockTransferK0M0M1K1
|
||||
struct DlBlockTransfer
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
@@ -242,56 +145,212 @@ struct DlBlockTransferK0M0M1K1
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
|
||||
|
||||
struct DlBlockTransferK0N0N1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
|
||||
|
||||
struct DlCThreadTransfer
|
||||
struct DlEpilogue
|
||||
{
|
||||
std::array<size_t, 6> src_dst_access_order;
|
||||
size_t src_dst_vector_dim;
|
||||
size_t dst_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
|
||||
static_assert(ckb::DlEpilogueDescriptor<DlEpilogue>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
// Factory types
|
||||
|
||||
struct ThreadBlock_
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct XdlGemm_
|
||||
{
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
{
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BlockTransfer_
|
||||
{
|
||||
BlockTransferABC block_transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
{
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
DlThreadConfig dl_thread_config;
|
||||
DlThreadCluster dl_thread_cluster;
|
||||
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
|
||||
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
|
||||
DlCThreadTransfer dl_c_thread_transfer;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
};
|
||||
|
||||
struct DlThreadConfig_
|
||||
{
|
||||
DlThreadConfig thread_config;
|
||||
};
|
||||
|
||||
struct DlThreadCluster_
|
||||
{
|
||||
DlThreadCluster thread_cluster;
|
||||
};
|
||||
|
||||
struct DlBlockTransfer_
|
||||
{
|
||||
DlBlockTransfer block_transfer_a;
|
||||
DlBlockTransfer block_transfer_b;
|
||||
};
|
||||
|
||||
struct DlEpilogue_
|
||||
{
|
||||
DlEpilogue epilogue_c;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
template <typename BaseAlgorithm>
|
||||
struct LargeTensorWrapper
|
||||
{
|
||||
BaseAlgorithm base_algorithm;
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Factory
|
||||
|
||||
template <typename... Components>
|
||||
struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
|
||||
template <typename TB>
|
||||
constexpr auto with_thread_block(const TB& tb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ThreadBlock_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_block = tb;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
constexpr auto with_gemm_config(const GemmConfig& gemm) const
|
||||
{
|
||||
auto result = *this;
|
||||
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<WmmaGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BT>
|
||||
constexpr auto with_block_transfer(const BT& bt) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer = bt;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.fwd_specialization = fwd_spec;
|
||||
result.gemm_specialization = gemm_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
|
||||
size_t groups_to_merge,
|
||||
PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
|
||||
result.num_groups_to_merge = groups_to_merge;
|
||||
result.loop_scheduler = scheduler;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TC>
|
||||
constexpr auto with_dl_thread_config(const TC& tc) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlThreadConfig_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_config = tc;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TCl>
|
||||
constexpr auto with_dl_thread_cluster(const TCl& tcl) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlThreadCluster_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_cluster = tcl;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BTA, typename BTB>
|
||||
constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlBlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer_a = bta;
|
||||
result.block_transfer_b = btb;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_dl_epilogue(const DlEpilogue& epi) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlEpilogue_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.epilogue_c = epi;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlBlockTransfer_,
|
||||
DlEpilogue_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -17,7 +17,6 @@ struct ConvSignature
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
GroupConvDeviceOp device_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ struct ConvSignature
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
|
||||
@@ -1,383 +0,0 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test implementation
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
PipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.block_gemm = BlockGemmDesc};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == PipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == PipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 2,
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
{
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
|
||||
{
|
||||
// DL thread configuration
|
||||
constexpr DlThreadConfig DlThreadCfg{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
// DL thread cluster
|
||||
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
// DL A block transfer - K0_M0_M1_K1 format
|
||||
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL B block transfer - K0_N0_N1_K1 format
|
||||
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL C thread transfer
|
||||
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.dl_thread_config = DlThreadCfg,
|
||||
.dl_thread_cluster = DlCluster,
|
||||
.dl_block_transfer_a = DlBlockTransferA,
|
||||
.dl_block_transfer_b = DlBlockTransferB,
|
||||
.dl_c_thread_transfer = DlCTransfer};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
// Large_Tensor uses the same descriptor as regular XDL CShuffle
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(
|
||||
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
184
experimental/builder/test/utils/ckb_conv_test_configs.hpp
Normal file
184
experimental/builder/test/utils/ckb_conv_test_configs.hpp
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
31
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
31
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
using namespace test;
|
||||
|
||||
// Common test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
|
||||
{
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
|
||||
for(const auto& component : kernel_instance_components)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
Reference in New Issue
Block a user