Merge branch 'develop' into jakpiase/gemm_pipeline_mem_skip_lds

This commit is contained in:
Damien Lejeune
2025-10-08 06:12:59 +00:00
16 changed files with 805 additions and 247 deletions

View File

@@ -41,7 +41,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200
ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit
path: "TheRock"
- name: Runner Health Settings

View File

@@ -252,15 +252,14 @@ struct SplitKTwoStageInvoker
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
};

View File

@@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
// For workspace mode, always use SET operation since each K-split writes to separate memory
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
/**

View File

@@ -187,15 +187,14 @@ struct UniversalInvoker
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
};

View File

@@ -70,99 +70,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
@@ -243,31 +239,28 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(!splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
return ave_time =
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"

View File

@@ -76,99 +76,95 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
#include "run_grouped_gemm_example.inc"

View File

@@ -109,23 +109,19 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ave_time;
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
#include "quant_run_grouped_gemm_example.inc"

View File

@@ -167,38 +167,38 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <template <typename PreType> typename FlatmmConfig>

View File

@@ -196,7 +196,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
[](auto i) {
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same<CLayout, DLayout>::value)
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
else
@@ -253,7 +253,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
static_for<0, NumDTensor, 1>{}([&](auto i) {
DsLengths[i] = out_lengths;
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
DsStrides[i] = {arg.StrideDs[i], 1};

View File

@@ -433,8 +433,13 @@ struct CShuffleEpilogue
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
static_assert(MPerXdl % RowsPerLane == 0,
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM2 = RowsPerLane;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
@@ -515,32 +520,25 @@ struct CShuffleEpilogue
// Pack 4 “rows per lane” as you already do
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t base = n_idx * c_warp_y_lengths.product();
const index_t plane = c_warp_y_lengths.product();
// local lambda to fuse scale (if present) and convert
auto emit = [&](index_t out_idx, index_t src_row) {
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
static_for<0, kM2, 1>{}([&](auto m_lane) {
const int src = n_idx * plane + m_lane; // source row in this N-plane
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
AccDataType v = shuffle_acc.get_thread_buffer()[src];
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales && !has_scalar_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
v = static_cast<AccDataType>(v * s_m * s_n);
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
v = static_cast<AccDataType>(v * sm * sn);
}
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
};
// Your current packing pattern (rows 0..3, spaced by NRepeat)
emit(n_idx + 0 * NRepeat, 0);
emit(n_idx + 1 * NRepeat, 1);
emit(n_idx + 2 * NRepeat, 2);
emit(n_idx + 3 * NRepeat, 3);
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
});
});
// store/update

View File

@@ -1134,6 +1134,7 @@ struct UniversalGemmKernel
while(block_id < num_work)
{
s_waitcnt_barrier();
// Get the tile index for this block
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);

View File

@@ -30,3 +30,4 @@ add_subdirectory(reduce)
add_subdirectory(epilogue)
add_subdirectory(atomic_add_op)
add_subdirectory(fmha)
add_subdirectory(gemm_tile_engine)

View File

@@ -0,0 +1,237 @@
# ============================================================================
# GEMM Tile Engine Unit Tests
#
# This CMake file creates unit tests for tile_engine generated GEMM kernels.
# It follows the exact same build patterns as tile_engine for consistency
# and reliability. Each kernel configuration gets its own test executable.
# ============================================================================
# Locate tile_engine GEMM scripts directory
set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm")
if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR})
message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}")
return()
endif()
# ============================================================================
# create_individual_gemm_test_target
#
# Creates a single test executable for a specific kernel configuration.
# Mirrors tile_engine's create_individual_gemm_target function for consistency.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# trait - Kernel trait combination string
# tile_config - Tile configuration parameters
# config_json - Full path to JSON configuration file
# ============================================================================
function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json)
set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Generated header path for this specific kernel configuration
set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Generate kernel header using tile_engine's Python script
add_custom_command(
OUTPUT ${test_header}
COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${config_json}
--gen_single
--kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json}
COMMENT "Generating test header ${test_header}"
VERBATIM
)
# Create GTest executable for this kernel configuration
add_gtest_executable(${target_name}
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp
)
# Ensure header is generated before compilation
set(header_target "${target_name}_header")
add_custom_target(${header_target} DEPENDS ${test_header})
add_dependencies(${target_name} ${header_target})
# Configure GPU architectures for HIP compilation
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
# Define preprocessor macros for generated header location
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
)
# Include directories for headers and dependencies
target_include_directories(${target_name} PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR} # Root directory for tile_engine access
${GTEST_INCLUDE_DIRS}
)
# Compiler options matching tile_engine requirements
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template # Suppress template warnings
-Wno-float-equal # Allow floating point comparisons
--offload-compress # Enable GPU code compression
-include ${test_header} # Auto-include generated header
)
message(STATUS " Created test target: ${target_name}")
endfunction()
# ============================================================================
# build_gemm_test_targets
#
# Builds all test targets for a specific datatype/layout/config combination.
# Uses tile_engine's two-step process: list kernels, then generate tests.
#
# Parameters:
# datatype - Data type (fp16, bf16, fp32, etc.)
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# ============================================================================
function(build_gemm_test_targets datatype layout config_name)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Locate and validate configuration file
set(config_filename "${config_name}.json")
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
message(STATUS " Using test config: ${config_filename}")
if(NOT EXISTS ${json_blob})
message(WARNING "Test config file not found: ${json_blob}")
return()
endif()
# Prepare build directory for this configuration
file(MAKE_DIRECTORY ${working_path})
# STEP 1: Discovery phase - list all valid kernel configurations
message(STATUS " Listing kernel configurations for ${datatype}_${layout}...")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}")
return()
endif()
# Validate kernel discovery results
if(EXISTS ${working_path}/gemm_kernel_count.txt)
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
string(STRIP "${kernel_count}" kernel_count)
message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}")
else()
message(WARNING "Kernel count file not found for ${datatype}_${layout}")
return()
endif()
# STEP 2: Generation phase - create test targets for each discovered kernel
if(EXISTS ${working_path}/gemm_kernel_list.txt)
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
set(test_count 0)
foreach(line IN LISTS kernel_lines)
# Parse kernel specification format: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(LENGTH parts parts_len)
if(parts_len EQUAL 3)
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Generate test target for this kernel configuration
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
math(EXPR test_count "${test_count} + 1")
endif()
endforeach()
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
else()
message(WARNING "Kernel list file not found for ${datatype}_${layout}")
endif()
endfunction()
# ============================================================================
# MAIN EXECUTION - Test Target Generation
# ============================================================================
message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GEMM_TEST_GPU_TARGETS ${target})
message(STATUS " Adding GPU target for tests: ${target}")
endif()
endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
# ============================================================================
# Test Configuration Matrix
# ============================================================================
# Available test configurations (minimal set for fast CI/testing)
set(TEST_CONFIGS
"simple_test_config"
# "medium_tiles_config" # Uncomment for broader testing
)
# Data types for testing (core precision types)
set(TEST_DATATYPES "fp16" "bf16")
# Extended data type options:
# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8")
# Matrix layouts for testing (row-column-row is most common)
set(TEST_LAYOUTS "rcr")
# Extended layout options:
# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr")
# ============================================================================
# Test Target Generation Loop
# ============================================================================
foreach(datatype IN LISTS TEST_DATATYPES)
foreach(layout IN LISTS TEST_LAYOUTS)
foreach(config IN LISTS TEST_CONFIGS)
set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json")
if(EXISTS ${CONFIG_FILE})
message(STATUS "Building tests for ${datatype}_${layout}_${config}")
build_gemm_test_targets("${datatype}" "${layout}" "${config}")
else()
message(WARNING "Config file not found: ${CONFIG_FILE}")
endif()
endforeach()
endforeach()
endforeach()
message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations")

View File

@@ -0,0 +1,27 @@
# GEMM Tile Engine Unit Tests
## How It Works
This unit test system integrates **tile_engine's kernel generation** into automated testing:
1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels
2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine)
3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers
4. **Individual test executables**: Each kernel configuration becomes a separate test
5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine
## Tile Engine Integration
```
JSON Config → tile_engine Python scripts → Generated Headers → Test Executables
```
- **`--list_kernels`**: Get available kernel configurations from JSON
- **`--gen_single`**: Generate individual kernel header for each configuration
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.

View File

@@ -0,0 +1,89 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"values": [
128
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
64
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,223 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Unit tests for tile_engine generated GEMM kernels
// Tests kernel correctness using tile_engine's verification methodology
#include <gtest/gtest.h>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "tile_engine/ops/gemm/gemm_common.hpp"
// The kernel header is included via compile command line with -include flag
// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types
// Adaptive error threshold calculation matching tile_engine's implementation
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations (from tile_engine)
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
bool compare_results(std::string instanceName,
ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
// Test parameter structure for matrix dimensions and split_k values
struct GemmTestParams
{
int m, n, k, split_k;
};
class GemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
{
protected:
void SetUp() override
{
auto params = GetParam();
m_ = params.m;
n_ = params.n;
k_ = params.k;
split_k_ = params.split_k;
// Calculate strides (following tile_engine pattern)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_a_ = k_;
}
else
{
stride_a_ = m_;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b_ = n_;
}
else
{
stride_b_ = k_;
}
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_c_ = n_;
}
else
{
stride_c_ = m_;
}
}
// Test dimensions
int m_, n_, k_, split_k_;
int stride_a_, stride_b_, stride_c_;
};
TEST_P(GemmTileEngineTest, BasicFunctionality)
{
// Get tensor layouts from generated kernel
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
// Use split_k from test parameters
int split_k = split_k_;
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
// Create host tensors with proper descriptors
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// Allocate GPU device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
// Copy data to device and zero output buffer
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
// Calculate reference result on host for verification
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_result);
// Create GEMM kernel arguments
ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
split_k,
m_,
n_,
k_,
stride_a_calc,
stride_b_calc,
stride_c_calc);
// Configure kernel execution for maximum speed (no timing, no debug output)
ck_tile::stream_config stream_config{nullptr, // stream
false, // time_kernel (disable timing for speed)
0, // log_level (disable debug output)
0, // n_warmup
1, // n_repeat
false, // is_gpu_timer (unused when time_kernel=false)
false, // flush_cache
1}; // rotating_count
// Launch the generated kernel (no timing overhead for fastest execution)
try
{
SelectedKernel::launch(gemm_args, stream_config);
// Kernel launched successfully if no exception thrown
}
catch(const std::exception& e)
{
FAIL() << "Kernel launch failed: " << e.what();
}
// Copy result back from device
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Verify results using tile_engine's adaptive error thresholds
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result);
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
}
TEST_P(GemmTileEngineTest, KernelInfo)
{
// Simple test to verify kernel information is available
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_
<< std::endl;
}
// Define test parameters for GEMM verification
INSTANTIATE_TEST_SUITE_P(GemmVerification,
GemmTileEngineTest,
::testing::Values(GemmTestParams{256, 256, 128, 1},
GemmTestParams{256, 256, 1024, 1},
GemmTestParams{256, 512, 512, 1},
GemmTestParams{512, 256, 512, 1}),
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
return std::to_string(param_info.param.m) + "x" +
std::to_string(param_info.param.n) + "x" +
std::to_string(param_info.param.k) + "_splitk" +
std::to_string(param_info.param.split_k);
});