Update conv specialization enum.

This commit is contained in:
Ville Pietilä
2025-12-29 05:06:39 -05:00
parent 30a9686877
commit 027d943b2f
16 changed files with 21 additions and 21 deletions

View File

@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v2_intrawave);

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -32,7 +32,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_128_64x64x64)
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
.with_transfer(Transfer_4x32x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -67,7 +67,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v5_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -35,7 +35,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)
.with_dl_transfer(DlFwdTransfer);
@@ -60,7 +60,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)

View File

@@ -24,7 +24,7 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
.with_thread_block(cku::ThreadBlock_256_256x256x32)
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::Transfer_4x64x1)
.with_fwd_specializations(ckb::ConvFwdSpecialization::DEFAULT,
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);

View File

@@ -29,7 +29,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
.with_transfer(Transfer_4x64x1_fp8)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT,
.with_fwd_specializations(ConvSpecialization::DEFAULT,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
@@ -67,7 +67,7 @@ TEST(
.with_thread_block(ThreadBlock_128_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
.with_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v3_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);

View File

@@ -30,7 +30,7 @@ TEST(FwdConvInstances,
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(Transfer_4x64x1)
.with_fwd-specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
.with_fwd-specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);

View File

@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);

View File

@@ -148,7 +148,7 @@ struct DefaultAlgorithm
},
};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};

View File

@@ -79,7 +79,7 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
constexpr struct Algorithm
{
ckb::ConvFwdSpecialization fwd_specialization =
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
} kAlgorithm;
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();