mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Merge branch 'develop' into jakpiase/gemm_pipeline_mem_skip_lds
This commit is contained in:
2
.github/workflows/therock-ci-linux.yml
vendored
2
.github/workflows/therock-ci-linux.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200
|
||||
ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit
|
||||
path: "TheRock"
|
||||
|
||||
- name: Runner Health Settings
|
||||
|
||||
@@ -252,15 +252,14 @@ struct SplitKTwoStageInvoker
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
// For workspace mode, always use SET operation since each K-split writes to separate memory
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -187,15 +187,14 @@ struct UniversalInvoker
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -70,99 +70,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -243,31 +239,28 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(!splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return ave_time =
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -76,99 +76,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -109,23 +109,19 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
|
||||
return ave_time;
|
||||
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "quant_run_grouped_gemm_example.inc"
|
||||
|
||||
@@ -167,38 +167,38 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
|
||||
@@ -196,7 +196,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
|
||||
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<CLayout, DLayout>::value)
|
||||
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
|
||||
else
|
||||
@@ -253,7 +253,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
DsLengths[i] = out_lengths;
|
||||
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
DsStrides[i] = {arg.StrideDs[i], 1};
|
||||
|
||||
@@ -433,8 +433,13 @@ struct CShuffleEpilogue
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
|
||||
|
||||
static_assert(MPerXdl % RowsPerLane == 0,
|
||||
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
|
||||
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM2 = RowsPerLane;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
@@ -515,32 +520,25 @@ struct CShuffleEpilogue
|
||||
// Pack 4 “rows per lane” as you already do
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t base = n_idx * c_warp_y_lengths.product();
|
||||
const index_t plane = c_warp_y_lengths.product();
|
||||
|
||||
// local lambda to fuse scale (if present) and convert
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
static_for<0, kM2, 1>{}([&](auto m_lane) {
|
||||
const int src = n_idx * plane + m_lane; // source row in this N-plane
|
||||
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[src];
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
|
||||
v = static_cast<AccDataType>(v * s_m * s_n);
|
||||
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
|
||||
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
|
||||
v = static_cast<AccDataType>(v * sm * sn);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
|
||||
};
|
||||
|
||||
// Your current packing pattern (rows 0..3, spaced by NRepeat)
|
||||
emit(n_idx + 0 * NRepeat, 0);
|
||||
emit(n_idx + 1 * NRepeat, 1);
|
||||
emit(n_idx + 2 * NRepeat, 2);
|
||||
emit(n_idx + 3 * NRepeat, 3);
|
||||
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
|
||||
});
|
||||
});
|
||||
|
||||
// store/update
|
||||
|
||||
@@ -1134,6 +1134,7 @@ struct UniversalGemmKernel
|
||||
|
||||
while(block_id < num_work)
|
||||
{
|
||||
s_waitcnt_barrier();
|
||||
// Get the tile index for this block
|
||||
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
|
||||
@@ -30,3 +30,4 @@ add_subdirectory(reduce)
|
||||
add_subdirectory(epilogue)
|
||||
add_subdirectory(atomic_add_op)
|
||||
add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
|
||||
237
test/ck_tile/gemm_tile_engine/CMakeLists.txt
Normal file
237
test/ck_tile/gemm_tile_engine/CMakeLists.txt
Normal file
@@ -0,0 +1,237 @@
|
||||
# ============================================================================
|
||||
# GEMM Tile Engine Unit Tests
|
||||
#
|
||||
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
|
||||
# It follows the exact same build patterns as tile_engine for consistency
|
||||
# and reliability. Each kernel configuration gets its own test executable.
|
||||
# ============================================================================
|
||||
|
||||
# Locate tile_engine GEMM scripts directory
|
||||
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm")
|
||||
|
||||
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
|
||||
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# create_individual_gemm_test_target
|
||||
#
|
||||
# Creates a single test executable for a specific kernel configuration.
|
||||
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# trait - Kernel trait combination string
|
||||
# tile_config - Tile configuration parameters
|
||||
# config_json - Full path to JSON configuration file
|
||||
# ============================================================================
|
||||
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
|
||||
set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Generated header path for this specific kernel configuration
|
||||
set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Generate kernel header using tile_engine's Python script
|
||||
add_custom_command(
|
||||
OUTPUT ${test_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating test header ${test_header}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
# Create GTest executable for this kernel configuration
|
||||
add_gtest_executable(${target_name}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp
|
||||
)
|
||||
|
||||
# Ensure header is generated before compilation
|
||||
set(header_target "${target_name}_header")
|
||||
add_custom_target(${header_target} DEPENDS ${test_header})
|
||||
add_dependencies(${target_name} ${header_target})
|
||||
|
||||
# Configure GPU architectures for HIP compilation
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
|
||||
|
||||
# Define preprocessor macros for generated header location
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
|
||||
)
|
||||
|
||||
# Include directories for headers and dependencies
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
|
||||
${GTEST_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
# Compiler options matching tile_engine requirements
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template # Suppress template warnings
|
||||
-Wno-float-equal # Allow floating point comparisons
|
||||
--offload-compress # Enable GPU code compression
|
||||
-include ${test_header} # Auto-include generated header
|
||||
)
|
||||
|
||||
message(STATUS " Created test target: ${target_name}")
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# build_gemm_test_targets
|
||||
#
|
||||
# Builds all test targets for a specific datatype/layout/config combination.
|
||||
# Uses tile_engine's two-step process: list kernels, then generate tests.
|
||||
#
|
||||
# Parameters:
|
||||
# datatype - Data type (fp16, bf16, fp32, etc.)
|
||||
# layout - Matrix layout (rcr, rrr, ccr, crr)
|
||||
# config_name - Configuration file name without .json extension
|
||||
# ============================================================================
|
||||
function(build_gemm_test_targets datatype layout config_name)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Locate and validate configuration file
|
||||
set(config_filename "${config_name}.json")
|
||||
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using test config: ${config_filename}")
|
||||
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(WARNING "Test config file not found: ${json_blob}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Prepare build directory for this configuration
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# STEP 1: Discovery phase - list all valid kernel configurations
|
||||
message(STATUS " Listing kernel configurations for ${datatype}_${layout}...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Validate kernel discovery results
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel count file not found for ${datatype}_${layout}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# STEP 2: Generation phase - create test targets for each discovered kernel
|
||||
if(EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(test_count 0)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate test target for this kernel configuration
|
||||
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
math(EXPR test_count "${test_count} + 1")
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel list file not found for ${datatype}_${layout}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
# MAIN EXECUTION - Test Target Generation
|
||||
# ============================================================================
|
||||
|
||||
message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===")
|
||||
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# GPU architecture filtering - only build tests for supported architectures
|
||||
set(GEMM_TEST_GPU_TARGETS "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
|
||||
message(STATUS " Adding GPU target for tests: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Early exit if no compatible GPU architectures are available
|
||||
if(NOT GEMM_TEST_GPU_TARGETS)
|
||||
message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
|
||||
|
||||
# ============================================================================
|
||||
# Test Configuration Matrix
|
||||
# ============================================================================
|
||||
|
||||
# Available test configurations (minimal set for fast CI/testing)
|
||||
set(TEST_CONFIGS
|
||||
"simple_test_config"
|
||||
# "medium_tiles_config" # Uncomment for broader testing
|
||||
)
|
||||
|
||||
# Data types for testing (core precision types)
|
||||
set(TEST_DATATYPES "fp16" "bf16")
|
||||
# Extended data type options:
|
||||
# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8")
|
||||
|
||||
# Matrix layouts for testing (row-column-row is most common)
|
||||
set(TEST_LAYOUTS "rcr")
|
||||
# Extended layout options:
|
||||
# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr")
|
||||
|
||||
# ============================================================================
|
||||
# Test Target Generation Loop
|
||||
# ============================================================================
|
||||
|
||||
foreach(datatype IN LISTS TEST_DATATYPES)
|
||||
foreach(layout IN LISTS TEST_LAYOUTS)
|
||||
foreach(config IN LISTS TEST_CONFIGS)
|
||||
set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json")
|
||||
if(EXISTS ${CONFIG_FILE})
|
||||
message(STATUS "Building tests for ${datatype}_${layout}_${config}")
|
||||
build_gemm_test_targets("${datatype}" "${layout}" "${config}")
|
||||
else()
|
||||
message(WARNING "Config file not found: ${CONFIG_FILE}")
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations")
|
||||
27
test/ck_tile/gemm_tile_engine/README.md
Normal file
27
test/ck_tile/gemm_tile_engine/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# GEMM Tile Engine Unit Tests
|
||||
|
||||
## How It Works
|
||||
|
||||
This unit test system integrates **tile_engine's kernel generation** into automated testing:
|
||||
|
||||
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
|
||||
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
|
||||
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
|
||||
4. **Individual test executables**: Each kernel configuration becomes a separate test
|
||||
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
|
||||
|
||||
## Tile Engine Integration
|
||||
|
||||
```
|
||||
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
|
||||
```
|
||||
|
||||
- **`--list_kernels`**: Get available kernel configurations from JSON
|
||||
- **`--gen_single`**: Generate individual kernel header for each configuration
|
||||
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
|
||||
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
|
||||
|
||||
|
||||
|
||||
|
||||
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
|
||||
@@ -0,0 +1,89 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
223
test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp
Normal file
223
test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Unit tests for tile_engine generated GEMM kernels
|
||||
// Tests kernel correctness using tile_engine's verification methodology
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "tile_engine/ops/gemm/gemm_common.hpp"
|
||||
|
||||
// The kernel header is included via compile command line with -include flag
|
||||
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
|
||||
|
||||
// Adaptive error threshold calculation matching tile_engine's implementation
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations (from tile_engine)
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
bool compare_results(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Test parameter structure for matrix dimensions and split_k values
|
||||
struct GemmTestParams
|
||||
{
|
||||
int m, n, k, split_k;
|
||||
};
|
||||
|
||||
class GemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
auto params = GetParam();
|
||||
m_ = params.m;
|
||||
n_ = params.n;
|
||||
k_ = params.k;
|
||||
split_k_ = params.split_k;
|
||||
|
||||
// Calculate strides (following tile_engine pattern)
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_a_ = k_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_a_ = m_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_b_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_b_ = k_;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_c_ = n_;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_c_ = m_;
|
||||
}
|
||||
}
|
||||
|
||||
// Test dimensions
|
||||
int m_, n_, k_, split_k_;
|
||||
int stride_a_, stride_b_, stride_c_;
|
||||
};
|
||||
|
||||
TEST_P(GemmTileEngineTest, BasicFunctionality)
|
||||
{
|
||||
// Get tensor layouts from generated kernel
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
// Use split_k from test parameters
|
||||
int split_k = split_k_;
|
||||
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
|
||||
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
|
||||
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
|
||||
|
||||
// Create host tensors with proper descriptors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
|
||||
|
||||
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
|
||||
// Allocate GPU device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy data to device and zero output buffer
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
// Calculate reference result on host for verification
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
|
||||
// Create GEMM kernel arguments
|
||||
ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
split_k,
|
||||
m_,
|
||||
n_,
|
||||
k_,
|
||||
stride_a_calc,
|
||||
stride_b_calc,
|
||||
stride_c_calc);
|
||||
|
||||
// Configure kernel execution for maximum speed (no timing, no debug output)
|
||||
ck_tile::stream_config stream_config{nullptr, // stream
|
||||
false, // time_kernel (disable timing for speed)
|
||||
0, // log_level (disable debug output)
|
||||
0, // n_warmup
|
||||
1, // n_repeat
|
||||
false, // is_gpu_timer (unused when time_kernel=false)
|
||||
false, // flush_cache
|
||||
1}; // rotating_count
|
||||
|
||||
// Launch the generated kernel (no timing overhead for fastest execution)
|
||||
try
|
||||
{
|
||||
SelectedKernel::launch(gemm_args, stream_config);
|
||||
// Kernel launched successfully if no exception thrown
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
FAIL() << "Kernel launch failed: " << e.what();
|
||||
}
|
||||
|
||||
// Copy result back from device
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Verify results using tile_engine's adaptive error thresholds
|
||||
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
|
||||
KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
|
||||
}
|
||||
|
||||
TEST_P(GemmTileEngineTest, KernelInfo)
|
||||
{
|
||||
// Simple test to verify kernel information is available
|
||||
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
|
||||
|
||||
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
|
||||
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Define test parameters for GEMM verification
|
||||
INSTANTIATE_TEST_SUITE_P(GemmVerification,
|
||||
GemmTileEngineTest,
|
||||
::testing::Values(GemmTestParams{256, 256, 128, 1},
|
||||
GemmTestParams{256, 256, 1024, 1},
|
||||
GemmTestParams{256, 512, 512, 1},
|
||||
GemmTestParams{512, 256, 512, 1}),
|
||||
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
|
||||
return std::to_string(param_info.param.m) + "x" +
|
||||
std::to_string(param_info.param.n) + "x" +
|
||||
std::to_string(param_info.param.k) + "_splitk" +
|
||||
std::to_string(param_info.param.split_k);
|
||||
});
|
||||
Reference in New Issue
Block a user