mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Update conv specialization enum.
This commit is contained in:
@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -32,7 +32,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x32x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
@@ -67,7 +67,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -35,7 +35,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -26,7 +26,7 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
@@ -60,7 +60,7 @@ TEST(FwdConvInstances,
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
|
||||
@@ -24,7 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
|
||||
.with_thread_block(cku::ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_fwd_specializations(ckb::ConvFwdSpecialization::DEFAULT,
|
||||
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
|
||||
.with_transfer(Transfer_4x64x1_fp8)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT,
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
@@ -67,7 +67,7 @@ TEST(
|
||||
.with_thread_block(ThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.with_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd-specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
.with_fwd-specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
@@ -148,7 +148,7 @@ struct DefaultAlgorithm
|
||||
},
|
||||
};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
|
||||
@@ -79,7 +79,7 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
|
||||
constexpr struct Algorithm
|
||||
{
|
||||
ckb::ConvFwdSpecialization fwd_specialization =
|
||||
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
} kAlgorithm;
|
||||
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user