Fix clang-formatting.

This commit is contained in:
Ville Pietilä
2025-11-04 12:04:09 +00:00
parent 930dcaab25
commit c1db7497af
3 changed files with 158 additions and 206 deletions

View File

@@ -299,35 +299,23 @@ consteval BlockGemmSpec SetBlockGemm()
switch(BG.scheduler)
{
case BlockGemmPipelineScheduler::INTRAWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
break;
case BlockGemmPipelineScheduler::INTERWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
break;
default:
throw "Unknown BlockGemmPipelineScheduler";
case BlockGemmPipelineScheduler::INTRAWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
break;
case BlockGemmPipelineScheduler::INTERWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
break;
default: throw "Unknown BlockGemmPipelineScheduler";
}
switch(BG.pipeline_version)
{
case BlockGemmPipelineVersion::V1:
version = ck::BlockGemmPipelineVersion::v1;
break;
case BlockGemmPipelineVersion::V2:
version = ck::BlockGemmPipelineVersion::v2;
break;
case BlockGemmPipelineVersion::V3:
version = ck::BlockGemmPipelineVersion::v3;
break;
case BlockGemmPipelineVersion::V4:
version = ck::BlockGemmPipelineVersion::v4;
break;
case BlockGemmPipelineVersion::V5:
version = ck::BlockGemmPipelineVersion::v5;
break;
default:
throw "Unknown BlockGemmPipelineVersion";
case BlockGemmPipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break;
case BlockGemmPipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break;
case BlockGemmPipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
case BlockGemmPipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
case BlockGemmPipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
default: throw "Unknown BlockGemmPipelineVersion";
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -436,15 +424,12 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
using ck_loop_sched = ck::LoopScheduler;
using ck_loop_sched = ck::LoopScheduler;
switch(loop_scheduler)
{
case LoopScheduler::DEFAULT:
return ck_loop_sched::Default;
case LoopScheduler::INTERWAVE:
return ck_loop_sched::Interwave;
default:
throw "Unknown LoopScheduler";
case LoopScheduler::DEFAULT: return ck_loop_sched::Default;
case LoopScheduler::INTERWAVE: return ck_loop_sched::Interwave;
default: throw "Unknown LoopScheduler";
}
}
@@ -452,21 +437,16 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
using ck_pipeline = ck::PipelineVersion;
using ck_pipeline = ck::PipelineVersion;
switch(pipeline_version)
{
case GridwiseGemmPipelineVersion::V1:
return ck_pipeline::v1;
case GridwiseGemmPipelineVersion::V2:
return ck_pipeline::v2;
case GridwiseGemmPipelineVersion::V4:
return ck_pipeline::v4;
case GridwiseGemmPipelineVersion::WEIGHT_ONLY:
return ck_pipeline::weight_only;
case GridwiseGemmPipelineVersion::V3:
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
default:
throw "Unknown GridwiseGemmPipelineVersion";
case GridwiseGemmPipelineVersion::V1: return ck_pipeline::v1;
case GridwiseGemmPipelineVersion::V2: return ck_pipeline::v2;
case GridwiseGemmPipelineVersion::V4: return ck_pipeline::v4;
case GridwiseGemmPipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
case GridwiseGemmPipelineVersion::V3:
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
default: throw "Unknown GridwiseGemmPipelineVersion";
}
}
@@ -474,44 +454,27 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
{
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
switch(gemm_spec)
{
case GemmSpecialization::Default:
return ck_gemm_spec::Default;
case GemmSpecialization::MPadding:
return ck_gemm_spec::MPadding;
case GemmSpecialization::NPadding:
return ck_gemm_spec::NPadding;
case GemmSpecialization::KPadding:
return ck_gemm_spec::KPadding;
case GemmSpecialization::MNPadding:
return ck_gemm_spec::MNPadding;
case GemmSpecialization::MKPadding:
return ck_gemm_spec::MKPadding;
case GemmSpecialization::NKPadding:
return ck_gemm_spec::NKPadding;
case GemmSpecialization::MNKPadding:
return ck_gemm_spec::MNKPadding;
case GemmSpecialization::OPadding:
return ck_gemm_spec::OPadding;
case GemmSpecialization::MOPadding:
return ck_gemm_spec::MOPadding;
case GemmSpecialization::NOPadding:
return ck_gemm_spec::NOPadding;
case GemmSpecialization::KOPadding:
return ck_gemm_spec::KOPadding;
case GemmSpecialization::MNOPadding:
return ck_gemm_spec::MNOPadding;
case GemmSpecialization::MKOPadding:
return ck_gemm_spec::MKOPadding;
case GemmSpecialization::NKOPadding:
return ck_gemm_spec::NKOPadding;
case GemmSpecialization::MNKOPadding:
return ck_gemm_spec::MNKOPadding;
default:
throw "Unknown GemmSpecialization";
case GemmSpecialization::Default: return ck_gemm_spec::Default;
case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding;
case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding;
case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding;
case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding;
case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding;
case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding;
case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding;
case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding;
case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding;
case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding;
case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding;
case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding;
case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding;
case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding;
case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding;
default: throw "Unknown GemmSpecialization";
}
}
@@ -519,21 +482,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
using ck_block_gemm = ck::BlockGemmPipelineVersion;
using ck_block_gemm = ck::BlockGemmPipelineVersion;
switch(version)
{
case BlockGemmPipelineVersion::V1:
return ck_block_gemm::v1;
case BlockGemmPipelineVersion::V2:
return ck_block_gemm::v2;
case BlockGemmPipelineVersion::V3:
return ck_block_gemm::v3;
case BlockGemmPipelineVersion::V4:
return ck_block_gemm::v4;
case BlockGemmPipelineVersion::V5:
return ck_block_gemm::v5;
default:
throw "Unknown BlockGemmPipelineVersion";
case BlockGemmPipelineVersion::V1: return ck_block_gemm::v1;
case BlockGemmPipelineVersion::V2: return ck_block_gemm::v2;
case BlockGemmPipelineVersion::V3: return ck_block_gemm::v3;
case BlockGemmPipelineVersion::V4: return ck_block_gemm::v4;
case BlockGemmPipelineVersion::V5: return ck_block_gemm::v5;
default: throw "Unknown BlockGemmPipelineVersion";
}
}
@@ -541,19 +498,14 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
{
constexpr auto specialization = ALGORITHM.fwd_specialization;
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(specialization)
{
case ConvFwdSpecialization::DEFAULT:
return ck_conv_spec::Default;
case ConvFwdSpecialization::FILTER_1X1_PAD0:
return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3:
return ck_conv_spec::Filter3x3;
default:
throw "Unknown ConvFwdSpecialization";
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
default: throw "Unknown ConvFwdSpecialization";
}
}

View File

@@ -12,124 +12,125 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
constexpr BlockTransferABC FwdBlockTransfer_4x64_1 {.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x16x1 {.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 16,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 16,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x32x1 {.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave {.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave {.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 2,
.n_xdl_per_wave = 1};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave {.k1 = 8,
.m_per_wmma = 32,
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = GridwiseGemmPipelineVersion::V1};
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
.m_per_wmma = 32,
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version =
GridwiseGemmPipelineVersion::V1};
constexpr ThreadBlock FwdThreadBlock_256x256x32 {.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128x128x32 {.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_64x32x32 {.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_64x64x64 {.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr ThreadBlock FwdThreadBlock_64x64x64{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V1,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler =
BlockGemmPipelineScheduler::INTRAWAVE};
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V2,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler =
BlockGemmPipelineScheduler::INTRAWAVE};
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V3,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler =
BlockGemmPipelineScheduler::INTRAWAVE};
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V4,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler =
BlockGemmPipelineScheduler::INTRAWAVE};
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V5,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler =
BlockGemmPipelineScheduler::INTRAWAVE};
} // namespace ck_tile::builder::test_utils

View File

@@ -14,7 +14,6 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Common test implementation
template <typename Builder>
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
@@ -28,7 +27,7 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
for (const auto& component : kernel_instance_components)
for(const auto& component : kernel_instance_components)
{
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
}