feat: Add Interwave scheduler for aquant memory pipeline (#3540)

* WIP: host level interwave pipeline compiles

* WIP: interwave implementation computes correct GEMM result when no aquant

* WIP: quantization works for subset of problem shapes

* WIP: interwave memory pipeline passes local test

* feat: Add interwave pipeline implementation for the memory pipeline in aquant

* test: add unit test for aquant memory pipeline

* fix: compilation error on gfx950

* chore: remove debug statements from the code

* test: resolve merge conflict

* test: remove non-RCR unit tests from test suite
Author: Aviral Goel
Date: 2026-01-27 00:57:42 +05:30
Committed by: GitHub
Parent: 3900e1e7ce
Commit: b8751e505d
11 changed files with 829 additions and 9 deletions
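
The scheduler's core idea is to split each thread's K-range (KPerThread) into KRepeat LDS-prefetch chunks of KPerInnerLoop, so MFMA work on one chunk can overlap with prefetching the next. Below is a minimal host-side sketch of that index math, using the same formulas as BlockGemmImpl<Interwave> in the diff; the concrete sizes are made-up example values, not taken from this PR.

// Host-side sketch of the K-chunking used by the interwave scheduler below.
// KPerThread, NumMacClusters, KPerInnerLoop, KRepeat, and KInnerLoopIter
// mirror the diff; the numeric values are hypothetical.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int KPerThread     = 32; // per-thread K work (example value)
    constexpr int NumMacClusters = 2;  // InterWaveSchedulingMacClusters (example value)
    constexpr int WarpGemmK      = 8;  // WarpGemm::kKPerThread (example value)

    // Same formulas as the diff: each wave works through KRepeat prefetch
    // chunks of KPerInnerLoop, each covering KInnerLoopIter warp-gemm steps.
    constexpr int KPerInnerLoop  = std::max(KPerThread / NumMacClusters, WarpGemmK);
    constexpr int KRepeat        = KPerThread / KPerInnerLoop;
    constexpr int KInnerLoopIter = KPerInnerLoop / WarpGemmK;

    for(int kIterGlobal = 0; kIterGlobal < KRepeat * KInnerLoopIter; ++kIterGlobal)
    {
        const int kRepeatIdx = kIterGlobal / KInnerLoopIter; // which chunk to prefetch
        const int kInnerIdx  = kIterGlobal % KInnerLoopIter; // offset inside the chunk
        std::printf("kIterGlobal=%d -> chunk %d, inner %d (LDS K offset %d)\n",
                    kIterGlobal, kRepeatIdx, kInnerIdx, kRepeatIdx * KPerInnerLoop);
    }
}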

@@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr
            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                CWarpTensor c_warp_tensor;
                // for every column in AQ
                static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
                    // for every warp corresponding to a quantization scale
                    static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                        constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
@@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr
        }
    };
    template <typename GemmTraits>
    struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
    {
        static constexpr index_t KPerThread     = GemmTraits::KPerThread;
        static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
        static constexpr index_t KPerInnerLoop =
            ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
        static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;
        static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;

        static constexpr auto ALdsTileDistr =
            make_static_tile_distribution(MakeABlockDistributionEncode());
        static constexpr auto BLdsTileDistr =
            make_static_tile_distribution(MakeBBlockDistributionEncode());

        using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
        using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));

        ALdsTile a_warp_tile_;
        BLdsTile b_warp_tile_;

        template <index_t KIdx,
                  typename ASmemBlockWindow,
                  typename BSmemBlockWindow,
                  bool ALoadTranspose = false,
                  bool BLoadTranspose = false>
        CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                          const BSmemBlockWindow& b_block_window,
                                          bool_constant<ALoadTranspose> = {},
                                          bool_constant<BLoadTranspose> = {})
        {
            constexpr auto a_lds_load_distr = [&]() {
                if constexpr(ALoadTranspose)
                    return make_static_tile_distribution(
                        typename InputTileDistributionTraits<
                            decltype(MakeABlockDistributionEncode()),
                            ADataType>::TransposedDstrEncode{});
                else
                    return make_static_tile_distribution(MakeABlockDistributionEncode());
            }();
            constexpr auto b_lds_load_distr = [&]() {
                if constexpr(BLoadTranspose)
                    return make_static_tile_distribution(
                        typename InputTileDistributionTraits<
                            decltype(MakeBBlockDistributionEncode()),
                            BDataType>::TransposedDstrEncode{});
                else
                    return make_static_tile_distribution(MakeBBlockDistributionEncode());
            }();
            constexpr auto a_lds_shape = []() {
                if constexpr(ALoadTranspose)
                    return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
                else
                    return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
            }();
            constexpr auto b_lds_shape = []() {
                if constexpr(BLoadTranspose)
                    return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
                else
                    return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
            }();

            constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
            constexpr auto a_offset =
                ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
            constexpr auto b_offset =
                BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};

            auto a_lds_gemm_window = make_tile_window(
                a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
            auto b_lds_gemm_window = make_tile_window(
                b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
            load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
                a_warp_tile_, a_lds_gemm_window);
            load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
                b_warp_tile_, b_lds_gemm_window);
        }

        // C += A * B with quantization support
        template <typename CBlockTensor,
                  typename AQBlockTensor,
                  typename ASmemBlockWindow,
                  typename BSmemBlockWindow,
                  bool ALoadTranspose = false,
                  bool BLoadTranspose = false>
        CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                       AQBlockTensor& aq_block_tensor,
                                       const ASmemBlockWindow& a_block_window,
                                       const BSmemBlockWindow& b_block_window,
                                       bool_constant<ALoadTranspose> a_load_tr = {},
                                       bool_constant<BLoadTranspose> b_load_tr = {})
        {
            static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                          "The CDataType as defined in traits should be the same as corresponding "
                          "C block tensor data type!");

            constexpr auto warp_size = get_warp_size();
            // Track which KRepeat chunk is currently loaded
            index_t current_k_repeat_loaded = -1;

            // Restructured loop: M → N → QScale → KIterPerQScale
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // Iterate over quantization groups
                    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
                        CWarpTensor c_warp_tensor;
                        // Accumulate K iterations for this quantization group
                        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                            // Map quantization indices to global K iteration
                            constexpr auto kIterGlobal =
                                kQScale * Traits::KIterPerQScale + kIterInQScale;
                            // Map to KRepeat chunk and KInnerLoopIter offset
                            constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter;
                            constexpr auto kInnerIdx  = kIterGlobal % KInnerLoopIter;
                            // Prefetch new chunk if needed
                            if constexpr(kInnerIdx == 0)
                            {
                                if(current_k_repeat_loaded != kRepeatIdx)
                                {
                                    LocalPrefetch<kRepeatIdx>(
                                        a_block_window, b_block_window, a_load_tr, b_load_tr);
                                    __builtin_amdgcn_sched_barrier(0);
                                    if constexpr(kRepeatIdx != 0 || KRepeat == 1)
                                    {
                                        __builtin_amdgcn_s_barrier();
                                        __builtin_amdgcn_sched_barrier(0);
                                    }
                                    current_k_repeat_loaded = kRepeatIdx;
                                }
                            }
                            // Load A warp tensor
                            AWarpTensor a_warp_tensor;
                            a_warp_tensor.get_thread_buffer() =
                                a_warp_tile_.get_y_sliced_thread_data(
                                    merge_sequences(sequence<mIter, kInnerIdx>{},
                                                    a_warp_y_index_zeros),
                                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
                            // Load B warp tensor
                            BWarpTensor b_warp_tensor;
                            b_warp_tensor.get_thread_buffer() =
                                b_warp_tile_.get_y_sliced_thread_data(
                                    merge_sequences(sequence<nIter, kInnerIdx>{},
                                                    b_warp_y_index_zeros),
                                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
                            // Synchronization barrier at the end of last iteration
                            if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 &&
                                         kIterInQScale == Traits::KIterPerQScale - 1 &&
                                         mIter.value == MIterPerWarp - 1 &&
                                         nIter.value == NIterPerWarp - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }
                            // Accumulate: first iteration initializes, rest accumulate
                            if constexpr(kIterInQScale == 0)
                            {
                                c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
                            }
                            else
                            {
                                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
                            }
                            // Set priority for scheduling
                            if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                        // Apply quantization scale after accumulating all K iterations for this
                        // group
                        constexpr auto tbuf_offset =
                            number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                                       merge_sequences(sequence<mIter, nIter>{},
                                                       c_warp_y_index_zeros)) /
                                   CBlockTensor::PackedSize>{};
                        AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
                            aq_block_tensor);
                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                            [&](auto c_row) {
                                float scale_reg_f = aq_picker.template pick<c_row>();
                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                            });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        }
    };
    public:
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
@@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr
        MakeCBlockTile();
    }
    template <typename ASmemBlockWindow,
    template <index_t KIdx = 0,
              typename ASmemBlockWindow,
              typename BSmemBlockWindow,
              bool ALoadTranspose = false,
              bool BLoadTranspose = false>
@@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr
                                   bool_constant<ALoadTranspose> a_load_tr = {},
                                   bool_constant<BLoadTranspose> b_load_tr = {})
    {
        block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
        if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
        {
            block_gemm_impl_.template LocalPrefetch<KIdx>(
                a_block_window, b_block_window, a_load_tr, b_load_tr);
        }
        else
        {
            block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
        }
    }
    // C += A * B
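
A note on the scale application at the end of the kQScale loop above: partial products of one quantization group are accumulated unscaled in c_warp_tensor, and the AQ scale is applied once per group when folding into c_block_tensor. The factoring is exact, since sum_k a_q[k]·s(g(k))·b[k] = sum_g s(g)·sum_{k in g} a_q[k]·b[k]. Below is a scalar reference sketch of the same factoring; the sizes and data are invented for illustration and only KIterPerQScale and the per-(row, group) scale correspond to concepts in the diff.

// Scalar reference for the group-wise scale application: accumulate a
// partial dot product over one quantization group, then apply the A-scale
// once. Purely illustrative; not the kernel's data layout.
#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    constexpr int K              = 8;
    constexpr int KIterPerQScale = 4;                    // K elements sharing one A scale
    constexpr int QScalesPerRow  = K / KIterPerQScale;

    std::vector<float> a_q   = {1, 2, 3, 4, 5, 6, 7, 8}; // dequant-ready A row (example)
    std::vector<float> b     = {1, 1, 2, 2, 3, 3, 4, 4}; // B column (example)
    std::vector<float> scale = {0.5f, 0.25f};            // one scale per group (example)

    // Reference: scale every product individually.
    float c_ref = 0.f;
    for(int k = 0; k < K; ++k)
        c_ref += a_q[k] * scale[k / KIterPerQScale] * b[k];

    // Interwave-style: accumulate per group, scale once (what the diff does
    // with c_warp_tensor and aq_picker).
    float c = 0.f;
    for(int g = 0; g < QScalesPerRow; ++g)
    {
        float partial = 0.f;
        for(int kk = 0; kk < KIterPerQScale; ++kk)
            partial += a_q[g * KIterPerQScale + kk] * b[g * KIterPerQScale + kk];
        c += partial * scale[g];
    }

    assert(std::fabs(c - c_ref) < 1e-5f);
    return 0;
}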

@@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
            return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
                .template operator()<HasHotLoop, TailNum>(
                    a_dram_block_window_tmp,
                    [](const OverrideADataType& a) { return a; },
                    [](const BDataType& a) { return a; },
                    b_dram_block_window_tmp,
                    [](const BDataType& b) { return b; },
                    aq_dram_block_window_tmp,
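
Finally, selecting between the two schedulers reduces to a compile-time branch on the scheduler enum, as the if constexpr added to LocalPrefetch above shows. A stripped-down, compilable illustration of that dispatch pattern follows; Pipeline and Run are hypothetical stand-ins, and only the enum values come from this PR.

// Sketch of compile-time scheduler dispatch, mirroring the `if constexpr`
// pattern in the first file. Everything except the enum values is invented.
#include <iostream>

enum class GemmPipelineScheduler { Intrawave, Interwave };

template <GemmPipelineScheduler Scheduler>
struct Pipeline
{
    static void Run()
    {
        if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
            std::cout << "interwave path: chunked LocalPrefetch<KIdx>\n";
        else
            std::cout << "intrawave path: single LocalPrefetch\n";
    }
};

int main()
{
    Pipeline<GemmPipelineScheduler::Interwave>::Run(); // selects the new path
    Pipeline<GemmPipelineScheduler::Intrawave>::Run(); // existing path
}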