mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feat: Add Interwave scheduler for aquant memory pipeline (#3540)
* WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipeline in aquant * test: add unit test for aquant memory pipeline * WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipeline in aquant * fix: compilation error on gfx950 * chore: remove debug statements from the code * test: resolve merge conflict * test: remove non rcr unit tests from test suite
This commit is contained in:
@@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// for every column in AQ
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
// for every warp corresponding to a quantization scale
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
@@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
}
|
||||
};
|
||||
|
||||
// Interwave-scheduler specialization of BlockGemmImpl.
//
// The per-thread K extent is split into KRepeat chunks of KPerInnerLoop elements.
// LocalPrefetch<KIdx> loads one chunk of the A and B block windows into the member
// tiles a_warp_tile_ / b_warp_tile_; operator() walks M -> N -> quantization group ->
// K iteration, runs WarpGemm on slices of those tiles, and adds the result multiplied
// by a scale picked from the AQ block tensor into the C block tensor.
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
|
||||
{
|
||||
// Per-thread K extent and number of MAC clusters, both taken from the traits.
static constexpr index_t KPerThread = GemmTraits::KPerThread;
|
||||
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
|
||||
|
||||
// K elements per chunk: KPerThread divided by the cluster count, but never smaller
// than the K extent of a single warp GEMM.
static constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
// Number of chunks covering KPerThread.
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
// Number of warp-GEMM K steps inside one chunk.
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
// Tile distributions built from the A/B block distribution encodings; they only
// serve to define the member tile types below.
static constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
static constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
// Tiles written by LocalPrefetch and sliced by operator().
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// Loads K chunk number KIdx of the A and B block windows into a_warp_tile_ and
// b_warp_tile_.
//
// KIdx            compile-time chunk index; the window origin along K is
//                 KIdx * KPerInnerLoop.
// a_block_window  window whose bottom tensor view is re-windowed for A.
// b_block_window  window whose bottom tensor view is re-windowed for B.
// ALoadTranspose / BLoadTranspose
//                 when true, K is the first window dimension (shape and origin are
//                 swapped) and the TransposedDstrEncode distribution is used.
template <index_t KIdx,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
// Tile distribution used for the A load (transposed encoding when requested).
constexpr auto a_lds_load_distr = [&]() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeABlockDistributionEncode()),
|
||||
ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
}();
|
||||
// Tile distribution used for the B load (transposed encoding when requested).
constexpr auto b_lds_load_distr = [&]() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeBBlockDistributionEncode()),
|
||||
BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
}();
|
||||
// Window shape for A: (MPerBlock, KPerInnerLoop), or (KPerInnerLoop, MPerBlock) when
// transposed.
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
// Window shape for B: (NPerBlock, KPerInnerLoop), or (KPerInnerLoop, NPerBlock) when
// transposed.
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
// Start of the requested chunk along K, placed in whichever dimension holds K.
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
|
||||
constexpr auto a_offset =
|
||||
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
constexpr auto b_offset =
|
||||
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
|
||||
// Re-window the callers' tensor views so that only the selected chunk is covered.
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
// NOTE(review): the A load passes BDataType (not ADataType) as the first template
// argument of load_int4_tile — confirm this is intentional and not a copy of the B call.
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_lds_gemm_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// C += A * B with quantization support
//
// For every (mIter, nIter) warp tile and every quantization group kQScale, the warp
// GEMM results of the group's KIterPerQScale K steps are accumulated in a local
// c_warp_tensor, then each element is multiplied by the scale picked from
// aq_block_tensor and added to c_block_tensor's thread buffer.
//
// c_block_tensor   accumulator updated in place; its DataType must equal CDataType.
// aq_block_tensor  scale source handed to AQPickerCommon.
// a_block_window / b_block_window
//                  forwarded to LocalPrefetch whenever a new K chunk is needed.
// a_load_tr / b_load_tr
//                  transpose flags forwarded to LocalPrefetch.
template <typename CBlockTensor,
|
||||
typename AQBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
AQBlockTensor& aq_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// Track which KRepeat chunk is currently loaded
// (-1 means no chunk has been prefetched yet; this is a runtime value, unlike the
// compile-time loop indices below).
index_t current_k_repeat_loaded = -1;
|
||||
|
||||
// Restructured loop: M → N → QScale → KIterPerQScale
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Iterate over quantization groups
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
// Unscaled accumulator for this (mIter, nIter, kQScale) combination.
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Accumulate K iterations for this quantization group
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
// Map quantization indices to global K iteration
constexpr auto kIterGlobal =
|
||||
kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
// Map to KRepeat chunk and KInnerLoopIter offset
constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter;
|
||||
constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter;
|
||||
|
||||
// Prefetch new chunk if needed
// (only checked on the first K step of a chunk).
// NOTE(review): when KRepeat > 1 the loaded chunk differs from chunk 0 at the start
// of every (mIter, nIter) pass, so all chunks are re-prefetched per pass — confirm
// this is the intended cost.
if constexpr(kInnerIdx == 0)
|
||||
{
|
||||
if(current_k_repeat_loaded != kRepeatIdx)
|
||||
{
|
||||
LocalPrefetch<kRepeatIdx>(
|
||||
a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// The s_barrier is skipped only for chunk 0 when there is more than one chunk.
if constexpr(kRepeatIdx != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
current_k_repeat_loaded = kRepeatIdx;
|
||||
}
|
||||
}
|
||||
|
||||
// Load A warp tensor
// (slice (mIter, kInnerIdx) of the prefetched A tile).
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kInnerIdx>{},
|
||||
a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// Load B warp tensor
// (slice (nIter, kInnerIdx) of the prefetched B tile).
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kInnerIdx>{},
|
||||
b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// Synchronization barrier at the end of last iteration
// (issued once, before the final warp GEMM of the whole M/N/QScale/K nest).
if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 &&
|
||||
kIterInQScale == Traits::KIterPerQScale - 1 &&
|
||||
mIter.value == MIterPerWarp - 1 &&
|
||||
nIter.value == NIterPerWarp - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Accumulate: first iteration initializes, rest accumulate
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Set priority for scheduling
// (raised after the first warp GEMM of a chunk, and only for the first M/N tile).
if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
|
||||
// Apply quantization scale after accumulating all K iterations for this
// group
// tbuf_offset: offset of the (mIter, nIter) warp tile inside the C block thread
// buffer, divided by CBlockTensor::PackedSize.
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
|
||||
// One scale is picked per thread-buffer element of the warp accumulator.
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
float scale_reg_f = aq_picker.template pick<c_row>();
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Priority is reset to 0 at the end of every mIter iteration.
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
@@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow,
|
||||
template <index_t KIdx = 0,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
@@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
|
||||
{
|
||||
block_gemm_impl_.template LocalPrefetch<KIdx>(
|
||||
a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
|
||||
.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const OverrideADataType& a) { return a; },
|
||||
[](const BDataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
|
||||
Reference in New Issue
Block a user