mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Problem descriptor
This commit is contained in:
@@ -4,19 +4,19 @@
|
||||
|
||||
#include "convolution_builder.hpp"
|
||||
|
||||
// Example of problem description for Forward Conv with default settings
|
||||
// Example of kernel description for Forward Conv with default settings
|
||||
struct GroupedConvFwdXdlImplicitGemm : public GroupedConvBaseXdlV1 {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::Forward;
|
||||
static constexpr ElementwiseOperation ElementwiseOperation_ = ElementwiseOperation::Bias;
|
||||
};
|
||||
|
||||
// Example of problem description for Backward Weight Conv with default settings and Split K Two Stage
|
||||
// Example of kernel description for Backward Weight Conv with default settings and Split K Two Stage
|
||||
struct GroupedConvBwdWeightXdlImplicitGemmTwoStage : public GroupedConvBaseXdlV1 {
|
||||
static constexpr ConvolutionDirection ConvolutionDirection_ = ConvolutionDirection::BackwardWeight;
|
||||
static constexpr SplitKSupport SplitKSupport_ = SplitKSupport::SupportedTwoStage;
|
||||
};
|
||||
|
||||
struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16ImplementationBaseV1 {
|
||||
struct Implementation16x16 : ImplementationDefaultV1 {
|
||||
static constexpr ck::index_t BlockSize_ = 64;
|
||||
static constexpr auto TileSizes_ = std::make_tuple(16, 16, 32);
|
||||
static constexpr ck::index_t K1_ = 8;
|
||||
@@ -26,10 +26,12 @@ struct ImplementationDescriptor : public NHWCImplementationBaseV1, public BF16Im
|
||||
static constexpr auto LDSStoreVectorSize_ = std::make_tuple(4, 4);
|
||||
};
|
||||
|
||||
struct ProblemBF16NHWGC : public BF16ProblemBaseV1, public NHWGCProblemBaseV1 {};
|
||||
|
||||
int main () {
|
||||
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm, ImplementationDescriptor> builder_fwd;
|
||||
ConvolutionBuilder<GroupedConvFwdXdlImplicitGemm, ProblemBF16NHWGC, Implementation16x16> builder_fwd;
|
||||
std::cout << builder_fwd.GetInstanceName() << std::endl;
|
||||
ConvolutionBuilder<GroupedConvBwdWeightXdlImplicitGemmTwoStage, ImplementationDescriptor> builder_bwd_weight_two_stage;
|
||||
ConvolutionBuilder<GroupedConvBwdWeightXdlImplicitGemmTwoStage, ProblemBF16NHWGC, Implementation16x16> builder_bwd_weight_two_stage;
|
||||
std::cout << builder_bwd_weight_two_stage.GetInstanceName() << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user