mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Fix clang-formatting.
This commit is contained in:
@@ -299,35 +299,23 @@ consteval BlockGemmSpec SetBlockGemm()
|
||||
|
||||
switch(BG.scheduler)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::INTRAWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
break;
|
||||
case BlockGemmPipelineScheduler::INTERWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
break;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineScheduler";
|
||||
case BlockGemmPipelineScheduler::INTRAWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
break;
|
||||
case BlockGemmPipelineScheduler::INTERWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
break;
|
||||
default: throw "Unknown BlockGemmPipelineScheduler";
|
||||
}
|
||||
|
||||
switch(BG.pipeline_version)
|
||||
{
|
||||
case BlockGemmPipelineVersion::V1:
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V2:
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V3:
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V4:
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V5:
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
break;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineVersion";
|
||||
case BlockGemmPipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break;
|
||||
case BlockGemmPipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break;
|
||||
case BlockGemmPipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case BlockGemmPipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case BlockGemmPipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
default: throw "Unknown BlockGemmPipelineVersion";
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -436,15 +424,12 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
case LoopScheduler::DEFAULT:
|
||||
return ck_loop_sched::Default;
|
||||
case LoopScheduler::INTERWAVE:
|
||||
return ck_loop_sched::Interwave;
|
||||
default:
|
||||
throw "Unknown LoopScheduler";
|
||||
case LoopScheduler::DEFAULT: return ck_loop_sched::Default;
|
||||
case LoopScheduler::INTERWAVE: return ck_loop_sched::Interwave;
|
||||
default: throw "Unknown LoopScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,21 +437,16 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
case GridwiseGemmPipelineVersion::V1:
|
||||
return ck_pipeline::v1;
|
||||
case GridwiseGemmPipelineVersion::V2:
|
||||
return ck_pipeline::v2;
|
||||
case GridwiseGemmPipelineVersion::V4:
|
||||
return ck_pipeline::v4;
|
||||
case GridwiseGemmPipelineVersion::WEIGHT_ONLY:
|
||||
return ck_pipeline::weight_only;
|
||||
case GridwiseGemmPipelineVersion::V3:
|
||||
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
|
||||
default:
|
||||
throw "Unknown GridwiseGemmPipelineVersion";
|
||||
case GridwiseGemmPipelineVersion::V1: return ck_pipeline::v1;
|
||||
case GridwiseGemmPipelineVersion::V2: return ck_pipeline::v2;
|
||||
case GridwiseGemmPipelineVersion::V4: return ck_pipeline::v4;
|
||||
case GridwiseGemmPipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
case GridwiseGemmPipelineVersion::V3:
|
||||
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -474,44 +454,27 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
|
||||
{
|
||||
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
switch(gemm_spec)
|
||||
{
|
||||
case GemmSpecialization::Default:
|
||||
return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding:
|
||||
return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding:
|
||||
return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding:
|
||||
return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding:
|
||||
return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding:
|
||||
return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding:
|
||||
return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding:
|
||||
return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding:
|
||||
return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding:
|
||||
return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding:
|
||||
return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding:
|
||||
return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding:
|
||||
return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding:
|
||||
return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding:
|
||||
return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding:
|
||||
return ck_gemm_spec::MNKOPadding;
|
||||
default:
|
||||
throw "Unknown GemmSpecialization";
|
||||
case GemmSpecialization::Default: return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding;
|
||||
default: throw "Unknown GemmSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -519,21 +482,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
using ck_block_gemm = ck::BlockGemmPipelineVersion;
|
||||
using ck_block_gemm = ck::BlockGemmPipelineVersion;
|
||||
switch(version)
|
||||
{
|
||||
case BlockGemmPipelineVersion::V1:
|
||||
return ck_block_gemm::v1;
|
||||
case BlockGemmPipelineVersion::V2:
|
||||
return ck_block_gemm::v2;
|
||||
case BlockGemmPipelineVersion::V3:
|
||||
return ck_block_gemm::v3;
|
||||
case BlockGemmPipelineVersion::V4:
|
||||
return ck_block_gemm::v4;
|
||||
case BlockGemmPipelineVersion::V5:
|
||||
return ck_block_gemm::v5;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineVersion";
|
||||
case BlockGemmPipelineVersion::V1: return ck_block_gemm::v1;
|
||||
case BlockGemmPipelineVersion::V2: return ck_block_gemm::v2;
|
||||
case BlockGemmPipelineVersion::V3: return ck_block_gemm::v3;
|
||||
case BlockGemmPipelineVersion::V4: return ck_block_gemm::v4;
|
||||
case BlockGemmPipelineVersion::V5: return ck_block_gemm::v5;
|
||||
default: throw "Unknown BlockGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -541,19 +498,14 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT:
|
||||
return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0:
|
||||
return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
|
||||
return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3:
|
||||
return ck_conv_spec::Filter3x3;
|
||||
default:
|
||||
throw "Unknown ConvFwdSpecialization";
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,124 +12,125 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64_1 {.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x64_1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x16x1 {.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x32x1 {.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 16,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave {.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave {.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave {.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = GridwiseGemmPipelineVersion::V1};
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version =
|
||||
GridwiseGemmPipelineVersion::V1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256x256x32 {.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128x128x32 {.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64x32x32 {.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr ThreadBlock FwdThreadBlock_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64x64x64 {.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr ThreadBlock FwdThreadBlock_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V1,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler =
|
||||
BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V2,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler =
|
||||
BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V3,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler =
|
||||
BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V4,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler =
|
||||
BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V5,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler =
|
||||
BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -14,7 +14,6 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
|
||||
// Common test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
|
||||
@@ -28,7 +27,7 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
|
||||
for (const auto& component : kernel_instance_components)
|
||||
for(const auto& component : kernel_instance_components)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user