mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
refactor: remove Default scheduler implementation as it not used anymore (#3542)
* refactor: remove Default scheduler implementation as it not used anymore * refactor: remove dead code from gemm universal kernel * chore: add descriptive comments about amd intrinsic hardware sync instructions * fix: label existing memory pipeline for aquant as intrawave
This commit is contained in:
@@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
@@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // Complete scheduling all pending instruction groups before this point
|
||||
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC
|
||||
// cluster, but except the first, as we can shorten non-MAC cluster a bit
|
||||
// and there's no observable negative impact. The desired effect is waves in
|
||||
@@ -460,8 +385,14 @@ struct BlockUniversalGemmAsBsCr
|
||||
// sync point.
|
||||
if constexpr(kIter.value != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// This pattern ensures:
|
||||
// At runtime: All waves synchronize (hardware barrier)
|
||||
// At compile-time: Instructions after the barrier don't get moved before it
|
||||
// (scheduling barrier)
|
||||
__builtin_amdgcn_s_barrier(); // Blocks execution until all waves (threads) in
|
||||
// the workgroup reach this point
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // Prevents instruction reordering across this boundary
|
||||
}
|
||||
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
|
||||
|
||||
@@ -1035,7 +1035,6 @@ struct UniversalGemmKernel
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
@@ -1161,9 +1160,7 @@ struct UniversalGemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
constexpr auto scheduler_type =
|
||||
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(
|
||||
RunGemm(
|
||||
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
// In the base situation, the Preshuffle setting should be false.
|
||||
|
||||
@@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
@@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
void* p_smem,
|
||||
index_t m = 0) const
|
||||
{
|
||||
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
|
||||
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
|
||||
.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const BDataType& a) { return a; },
|
||||
|
||||
Reference in New Issue
Block a user