Add a new gemm pipeline based on ComputeV4 which utilizes async copy API (#2949)

* check in pipeline and policy

for async  load in mi350, need to make sure TileAccessPattern is warp_raked or block_raked

solve merge conflicts

* fix cmakelists

* make it build

* fix? buffer async fence

* relax fences; it appears it only is needed between pairs of ping-pongs

* remove fences

* remove fences

* cleanup and reformat

* add steps annotations

* comment all pipeline steps / remove unexplainable syncs

* clang-format

* add comment

* cleanup kernel types for test

* fix comment

* fix hardcoded warp size

* faithfully copy block gemm from compute v4 policy to async policy

* make async test gfx950 only

* fix cmake logic

* set separate compile options for async

* refine comment in policy

* try update hotloop scheduler

* cleanup comments

* test more K block sizes

* unhardcode Ks, sort of

* add large odd test case

* fix build for quant

* add comment to hot loop scheduler and rename enum

* reformat

* reword the pipeline description

* reformat

* address review / add static asserts / typo fix

* update changelog
This commit is contained in:
Max Podkorytov
2025-10-01 15:38:07 -07:00
committed by GitHub
parent f2d367262f
commit a7da3c68b9
13 changed files with 803 additions and 62 deletions

View File

@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.0.0
### Added
* Added a compute async pipeline in the CK TILE universal GEMM on gfx950
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
* Added the new api to load different memory sizes to SGPR.
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.

View File

@@ -275,4 +275,20 @@ CK_TILE_DEVICE static constexpr auto get_device_arch()
return gfx12_t{};
#endif
}
enum LLVMSchedGroupMask : int32_t
{
NONE = 0,
ALU = 1 << 0,
VALU = 1 << 1,
SALU = 1 << 2,
MFMA = 1 << 3,
VMEM = 1 << 4,
VMEM_READ = 1 << 5,
VMEM_WRITE = 1 << 6,
DS = 1 << 7,
DS_READ = 1 << 8,
DS_WRITE = 1 << 9,
ALL = (DS_WRITE << 1) - 1,
};
} // namespace ck_tile

View File

@@ -40,6 +40,8 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"

View File

@@ -0,0 +1,531 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompAsync
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
}
else
{
return TailNumber::Two;
}
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Handle all the valid cases.
if(has_hot_loop)
{
if(tail_number == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
}
else
{
if(tail_number == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
__builtin_unreachable();
#else
throw std::logic_error(
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
#endif
}
};
/**
* @brief Compute optimized pipeline version async; which is based on V4.
*
* This pipeline introduces asynchronous load from global memory to LDS,
* skipping the intermediate loading into pipeline registers.
*/
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = get_warp_size();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
// TODO: this will likely need to be redesigned after (1) changes to reading from
// LDS and (2) re-profiling
ignore = i;
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
// TODO currently fused elementwise are not supported
ignore = a_element_func;
ignore = b_element_func;
static_assert(std::is_same_v<remove_cvref_t<decltype(a_element_func)>,
element_wise::PassThrough>);
static_assert(std::is_same_v<remove_cvref_t<decltype(b_element_func)>,
element_wise::PassThrough>);
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
// TODO currently only support A matrix row major, B matrix col major; if A matrix is
// col major or B is row major, need to combine with transpose load api
static_assert(!(is_a_col_major || is_b_row_major),
"only support A matrix is row major, B matrix is col major!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// B DRAM window(s) for load
auto b_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
// LDS tile windows for storing, one per LDS buffer
auto a_copy_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_copy_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// initialize DRAM window steps, used to advance the DRAM windows
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// initialize block gemm
auto block_gemm = BlockGemm();
// initialize C block tile
auto c_block_tile = block_gemm.MakeCBlockTile();
clear_tile(c_block_tile);
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
ALdsTile a_block_tile0;
ALdsTile a_block_tile1;
BLdsTile b_block_tile0;
BLdsTile b_block_tile1;
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing
// but also associate with a distribution to produce a register tile when reading
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
"LDS windows must not be linear");
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(0), B(0) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// LDS window(0) contents are overwritten below by global prefetch, need to sync
block_sync_lds();
// read A(2), B(2) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
if(HasHotLoop)
{
// we have had 3 global prefetches so far, indexed (0, 1, 2).
index_t i_global_read = amd_wave_read_first_lane(3);
// alternate ping: (read to register tile(1), use register tile(0) as gemm input)
// pong: (read to register tile(0), use register tile(1) as gemm input)
do
{
// ping
{
// read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
// LDS window(1) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i), B(i) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window1,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window1,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
}
// pong
{
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
// read A(i), B(i) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// LDS window(0) contents are overwritten by global prefetch, need to sync
block_sync_lds();
// read A(i+1), B(i+1) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window0,
a_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window0,
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
}
i_global_read += 2;
} while(i_global_read < num_loop);
}
// 3 block gemms remaining
if constexpr(TailNum == TailNumber::Three)
{
{
// read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
{
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
}
else
// 2 block gemms remaining
{
{
// read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
{
// C(num_loop) = A(num_loop) @ B(num_loop)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
void* p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem_0,
p_smem_1);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem_0,
p_smem_1);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAgBgCrCompAsync
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
{
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -9,6 +9,26 @@
namespace ck_tile {
template <typename T, typename = void>
struct has_a_tile_access_pattern : std::false_type
{
};
template <typename T>
struct has_a_tile_access_pattern<T, std::void_t<decltype(T::ATileAccessPattern)>> : std::true_type
{
};
template <typename T, typename = void>
struct has_b_tile_access_pattern : std::false_type
{
};
template <typename T>
struct has_b_tile_access_pattern<T, std::void_t<decltype(T::BTileAccessPattern)>> : std::true_type
{
};
template <typename Derived>
struct UniversalGemmBasePolicy
{
@@ -30,8 +50,25 @@ struct UniversalGemmBasePolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
// Default tile access patterns
static constexpr auto DefaultATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto DefaultBTileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto getATileAccessPattern()
{
if constexpr(has_a_tile_access_pattern<Derived>::value)
return Derived::ATileAccessPattern;
else
return DefaultATileAccessPattern;
}
static constexpr auto getBTileAccessPattern()
{
if constexpr(has_b_tile_access_pattern<Derived>::value)
return Derived::BTileAccessPattern;
else
return DefaultBTileAccessPattern;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
@@ -168,11 +205,12 @@ struct UniversalGemmBasePolicy
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern()>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
@@ -500,23 +538,25 @@ struct UniversalGemmBasePolicy
// Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: KPerBlock X MPerBlock
else
{
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
NumWaveGroups>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
@@ -536,23 +576,25 @@ struct UniversalGemmBasePolicy
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
BTileAccessPattern,
NumWaveGroups>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
@@ -573,7 +615,7 @@ struct UniversalGemmBasePolicy
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}
@@ -594,7 +636,7 @@ struct UniversalGemmBasePolicy
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
}

View File

@@ -15,9 +15,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using Base::I1;
using Base::I2;
using Base::ATileAccessPattern;
using Base::BTileAccessPattern;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
{

View File

@@ -15,9 +15,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using Base::I1;
using Base::I2;
using Base::ATileAccessPattern;
using Base::BTileAccessPattern;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
{

View File

@@ -11,6 +11,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
-mllvm
-enable-noalias-to-md-conversion=0
)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
@@ -60,6 +61,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async test_gemm_pipeline_comp_async.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_comp_async PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx11|gfx12")
# On Radeon devices, build the WMMA version instead
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)

View File

@@ -0,0 +1,17 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineCompAsync
: public TestCkTileGemmPipeline<T, class TestCkTileGemmPipelineCompAsync<T>>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompAsync
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync, KernelTypesCompAsync);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -26,9 +26,10 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using CompAsync = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompAsync>;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
@@ -129,6 +130,10 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
>;
using KernelTypesCompAsync = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
>;
using KernelTypesCompV4Wmma = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,

View File

@@ -10,18 +10,25 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
constexpr int K = 320;
std::vector<int> Ks;
for (auto K_count: {2, 3, 4, 10, 11})
{
Ks.push_back(K_count * TestFixture::K_Tile);
}
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
for(int K : Ks)
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
}
}
@@ -30,7 +37,12 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
std::vector<int> Ks;
for (auto K_count: {2, 3, 4, 10, 11})
{
Ks.push_back(K_count * TestFixture::K_Tile);
}
constexpr int VecLoadSize = (std::is_same_v<typename TestFixture::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TestFixture::ADataType, ck_tile::bf8_t> ||
std::is_same_v<typename TestFixture::ADataType, ck_tile::int8_t>)
@@ -39,22 +51,25 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
for (int K: Ks)
{
if(M % VecLoadSize == 0)
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
this->Run(M, N, K);
if(M % VecLoadSize == 0)
{
this->Run(M, N, K);
}
else
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
}
else
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
this->Run(M, N, K);
}
}
else
{
this->Run(M, N, K);
}
}
}

View File

@@ -37,7 +37,8 @@ enum struct GemmPipelineType
{
Mem,
CompV3,
CompV4
CompV4,
CompAsync
};
template <GemmPipelineType PT, typename Problem>
@@ -70,6 +71,15 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync<Problem>;
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
};
template <typename Tuple, typename Derived>
class TestCkTileGemmPipeline : public ::testing::Test
{
@@ -110,7 +120,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadK = PadK;
constexpr bool preshuffle = Preshuffle;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 ||
PipelineType == GemmPipelineType::CompAsync);
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;