Merge commit '13f6d635653bd5ffbfcac8577f1ef09590c23d78' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-05 03:38:26 +00:00
parent 5da2114921
commit eeadb34e8f
2 changed files with 186 additions and 279 deletions

View File

@@ -251,14 +251,10 @@ class ConvDescription : public Description
};
} // namespace conv
/// @brief Helper concept to detect if a type has ConvTraits specialization
template <typename T>
concept HasConvTraits = requires { typename conv::ConvTraits<T>; };
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have InstanceTraits specialization)
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
/// @return A ConvDescription object populated with the instance's configuration details
template <HasConvTraits Instance>
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;

View File

@@ -21,6 +21,57 @@
namespace ck_tile::reflect::conv {
// Forward convolution layout concept - checks for A/B/E layout types
template <typename T>
concept HasFwdConvLayouts = requires {
typename T::ALayout;
typename T::BLayout;
typename T::ELayout;
};
// GEMM specialization concept - checks for kGemmSpecialization member
template <typename T>
concept HasGemmSpec = requires {
{
T::kGemmSpecialization
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
};
// Data types concept - checks for ADataType member
template <typename T>
concept HasDataTypes = requires { typename T::ADataType; };
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
template <typename T>
concept HasElementwiseOps = requires {
typename T::AElementwiseOperation;
typename T::BElementwiseOperation;
typename T::CDEElementwiseOperation;
};
// Tile parameters concept - checks for tile dimension and transfer members
template <typename T>
concept HasTileParams = requires {
{ T::kKPerBlock } -> std::convertible_to<int>;
{ T::kMPerBlock } -> std::convertible_to<int>;
{ T::kNPerBlock } -> std::convertible_to<int>;
{ T::kAK1 } -> std::convertible_to<int>;
{ T::kBK1 } -> std::convertible_to<int>;
T::kCThreadClusterLengths;
};
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
// This concept is used to constrain ConvTraits specializations that expect XDL forward convolutions
template <typename T>
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
HasElementwiseOps<T> && HasTileParams<T>;
// Primary concept for checking if a type can be described
// Currently only forward convolutions are supported, but this can be extended
// in the future to include backward data and backward weight convolutions
template <typename T>
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
@@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version()
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v3)
return V3;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == v5)
return V5;
switch(ck_ver)
{
case v1: return V1;
case v2: return V2;
case v3: return V3;
case v4: return V4;
case v5: return V5;
}
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
@@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version()
{
using enum ck::PipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == weight_only)
return WEIGHT_ONLY;
switch(ck_ver)
{
case v1: return V1;
case v2: return V2;
case v4: return V4;
case weight_only: return WEIGHT_ONLY;
}
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
@@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler()
{
using enum ck::BlockGemmPipelineScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Intrawave)
return INTRAWAVE;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
switch(ck_sched)
{
case Intrawave: return INTRAWAVE;
case Interwave: return INTERWAVE;
}
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler()
{
using enum ck::LoopScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Default)
return DEFAULT;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
switch(ck_sched)
{
case Default: return DEFAULT;
case Interwave: return INTERWAVE;
}
}
/// @brief Helper structures for organizing trait data with domain-specific naming
@@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction()
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
{
return builder::ConvDirection::FORWARD;
}
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
{
return builder::ConvDirection::BACKWARD_DATA;
}
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
{
return builder::ConvDirection::BACKWARD_WEIGHT;
}
else
{
return builder::ConvDirection::FORWARD; // Default fallback
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
@@ -242,60 +288,52 @@ constexpr auto conv_spec()
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
using enum builder::ConvFwdSpecialization;
if constexpr(InstTraits::kConvForwardSpecialization == Default)
switch(InstTraits::kConvForwardSpecialization)
{
return builder::ConvFwdSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
{
return builder::ConvFwdSpecialization::FILTER_3x3;
case Default: return DEFAULT;
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter3x3: return FILTER_3x3;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
using enum builder::ConvBwdDataSpecialization;
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
switch(InstTraits::kConvBwdDataSpecialization)
{
return builder::ConvBwdDataSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
case Default: return DEFAULT;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
}
}
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
using enum builder::ConvBwdWeightSpecialization;
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
switch(InstTraits::kConvBwdWeightSpecialization)
{
return builder::ConvBwdWeightSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
{
return builder::ConvBwdWeightSpecialization::ODD_C;
case Default: return DEFAULT;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case OddC: return ODD_C;
}
}
}
// Helper variable template to check if CK layout enums match
template <typename A,
typename B,
typename E,
typename ExpectedA,
typename ExpectedB,
typename ExpectedE>
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
@@ -304,112 +342,49 @@ constexpr auto conv_spec()
/// index 2 -> Output layout
template <typename Instance>
constexpr auto conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using ALayout = typename InstTraits::ALayout;
using BLayout = typename InstTraits::BLayout;
using ELayout = typename InstTraits::ELayout;
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
namespace ctc = ck::tensor_layout::convolution;
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
namespace ctl = ck::tensor_layout::convolution;
using enum builder::TensorLayout;
if constexpr(InstTraits::kSpatialDim == 1)
switch(InstanceTraits<Instance>::kSpatialDim)
{
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
std::is_same_v<ELayout, ctc::GNWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
builder::TensorLayout::GKXC,
builder::TensorLayout::GNWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
builder::TensorLayout::GKXC,
builder::TensorLayout::NWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
builder::TensorLayout::GKXC,
builder::TensorLayout::NGKW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
builder::TensorLayout::GKCX,
builder::TensorLayout::NGKW};
}
}
else if constexpr(InstTraits::kSpatialDim == 2)
{
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::GNHWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
builder::TensorLayout::GKYXC,
builder::TensorLayout::GNHWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NHWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
builder::TensorLayout::GKYXC,
builder::TensorLayout::NHWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
builder::TensorLayout::GKYXC,
builder::TensorLayout::NGKHW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKCYX> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
builder::TensorLayout::GKCYX,
builder::TensorLayout::NGKHW};
}
}
else if constexpr(InstTraits::kSpatialDim == 3)
{
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::GNDHWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::GNDHWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NDHWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::NDHWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::NGKDHW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKCZYX> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
builder::TensorLayout::GKCZYX,
builder::TensorLayout::NGKDHW};
}
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
return layouts(NGCW, GKXC, NGKW);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
return layouts(NGCW, GKCX, NGKW);
break;
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
return layouts(NGCHW, GKCYX, NGKHW);
break;
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
return layouts(NGCDHW, GKZYXC, NGKDHW);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
}
@@ -418,39 +393,26 @@ constexpr auto conv_layout()
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
template <typename Instance>
constexpr builder::DataType conv_data_type()
requires HasDataTypes<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
using enum builder::DataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
{
return builder::DataType::FP16;
}
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
{
return builder::DataType::BF16;
}
return BF16;
else if constexpr(std::is_same_v<ADataType, float>)
{
return builder::DataType::FP32;
}
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
{
return builder::DataType::FP8;
}
return FP8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
{
return builder::DataType::I8;
}
return I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
{
return builder::DataType::U8;
}
return U8;
else
{
// Default fallback
return builder::DataType::FP32;
}
return FP32; // Default fallback
}
/// @brief Derives the elementwise operation from op type.
@@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type()
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
using enum builder::ElementwiseOperation;
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
{
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
{
return builder::ElementwiseOperation::CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
{
return builder::ElementwiseOperation::SCALE;
}
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
{
return builder::ElementwiseOperation::PASS_THROUGH;
}
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
{
return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
}
return BIAS_BNORM_CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
return CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Scale"))
return SCALE;
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
return PASS_THROUGH;
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
return SCALEADD_SCALEADD_RELU;
}
/// @brief Derives a gemm padding from a kernel instance type.
@@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op()
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
requires HasGemmSpec<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::GemmPadding;
@@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec()
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
if constexpr(gemm_spec == Default)
switch(gemm_spec)
{
return DEFAULT;
}
else if constexpr(gemm_spec == MPadding)
{
return M_PADDING;
}
else if constexpr(gemm_spec == NPadding)
{
return N_PADDING;
}
else if constexpr(gemm_spec == KPadding)
{
return K_PADDING;
}
else if constexpr(gemm_spec == MNPadding)
{
return MN_PADDING;
}
else if constexpr(gemm_spec == MKPadding)
{
return MK_PADDING;
}
else if constexpr(gemm_spec == NKPadding)
{
return NK_PADDING;
}
else if constexpr(gemm_spec == MNKPadding)
{
return MNK_PADDING;
}
else if constexpr(gemm_spec == OPadding)
{
return O_PADDING;
}
else if constexpr(gemm_spec == MOPadding)
{
return MO_PADDING;
}
else if constexpr(gemm_spec == NOPadding)
{
return NO_PADDING;
}
else if constexpr(gemm_spec == KOPadding)
{
return KO_PADDING;
}
else if constexpr(gemm_spec == MNOPadding)
{
return MNO_PADDING;
}
else if constexpr(gemm_spec == MKOPadding)
{
return MKO_PADDING;
}
else if constexpr(gemm_spec == NKOPadding)
{
return NKO_PADDING;
}
else if constexpr(gemm_spec == MNKOPadding)
{
return MNKO_PADDING;
case Default: return DEFAULT;
case MPadding: return M_PADDING;
case NPadding: return N_PADDING;
case KPadding: return K_PADDING;
case MNPadding: return MN_PADDING;
case MKPadding: return MK_PADDING;
case NKPadding: return NK_PADDING;
case MNKPadding: return MNK_PADDING;
case OPadding: return O_PADDING;
case MOPadding: return MO_PADDING;
case NOPadding: return NO_PADDING;
case KOPadding: return KO_PADDING;
case MNOPadding: return MNO_PADDING;
case MKOPadding: return MKO_PADDING;
case NKOPadding: return NKO_PADDING;
case MNKOPadding: return MNKO_PADDING;
}
}
@@ -571,6 +481,7 @@ struct ConvTraits;
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <HasInstanceTraits Instance>
requires IsXdlFwdConv<InstanceTraits<Instance>>
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;