feat: Add Interwave scheduler for aquant memory pipeline (#3540)

* WIP: host level interwave pipeline compiles

* WIP: interwave implementation computes correct GEMM result when no aquant

* WIP: quantization works for subset of problem shapes

* WIP: quantization works for subset of problem shapes

* WIP: interwave memory pipeline passes local test

* feat: Add interwave pipeline implementation for memory pipeline in aquant

* test: add unit test for aquant memory pipeline

* WIP: host level interwave pipeline compiles

* WIP: interwave implementation computes correct GEMM result when no aquant

* WIP: quantization works for subset of problem shapes

* WIP: quantization works for subset of problem shapes

* WIP: interwave memory pipeline passes local test

* feat: Add interwave pipeline implementation for memory pipeline in aquant

* fix: compilation error on gfx950

* chore: remove debug statements from the code

* test: resolve merge conflict

* test: remove non rcr unit tests from test suite
This commit is contained in:
Aviral Goel
2026-01-27 00:57:42 +05:30
committed by GitHub
parent 3900e1e7ce
commit b8751e505d
11 changed files with 829 additions and 9 deletions

View File

@@ -11,7 +11,24 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time
# AQuant tests - split into 6 files
# AQuant tests - split into 10 files
# AQuant Memory Pipeline tests
add_gtest_executable(test_tile_gemm_quant_aquant_mem_prefill_interwave
test_gemm_quant_aquant_mem_prefill_interwave.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_mem_prefill_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_intrawave
test_gemm_quant_aquant_mem_decode_intrawave.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_mem_decode_intrawave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_interwave
test_gemm_quant_aquant_mem_decode_interwave.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_mem_decode_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr
test_gemm_quant_aquant_base_rcr.cpp
)
@@ -150,10 +167,21 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# Target to build only AQuant memory pipeline tests
add_custom_target(test_tile_gemm_aquant_mem_all)
add_dependencies(test_tile_gemm_aquant_mem_all
test_tile_gemm_quant_aquant_mem_prefill_interwave
test_tile_gemm_quant_aquant_mem_decode_intrawave
test_tile_gemm_quant_aquant_mem_decode_interwave
)
# Umbrella target to build all gemm quant tests
add_custom_target(test_tile_gemm_quant_all)
add_dependencies(test_tile_gemm_quant_all
# AQuant tests
test_tile_gemm_quant_aquant_mem_prefill_interwave
test_tile_gemm_quant_aquant_mem_decode_intrawave
test_tile_gemm_quant_aquant_mem_decode_interwave
test_tile_gemm_quant_aquant_base_rcr
test_tile_gemm_quant_aquant_base_rrr_crr
test_tile_gemm_quant_aquant_base_ccr

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Unit tests for the AQuant GEMM memory pipeline using a decode-shaped block
// tile together with the Interwave scheduler (GemmConfigDecodeInterwave,
// declared in test_gemm_quant_fixtures.hpp). Built as its own executable so
// each scheduler/tile combination compiles in a separate translation unit.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
// One quantization scale per 128 contiguous K elements of A.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - Mem Decode Interwave Configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// All cases use the RCR layout; A is fp8/bf8 or packed-int4 with fp8/bf8 scales.
// clang-format off
using AQuantMemDecodeInterwaveTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigDecodeInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigDecodeInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigDecodeInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigDecodeInterwave, GroupSize>
>;
// clang-format on
// Test suite for AQuant Mem Decode Interwave
TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTest)
{
// Decode-like problem shape: small M (16), N=64, K=512 (multiple quant
// groups of 128 along K).
this->run_test_with_validation(16, 64, 512);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - Mem Decode Intrawave Configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantMemDecodeIntrawaveTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigDecodeIntrawave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigDecodeIntrawave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigDecodeIntrawave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigDecodeIntrawave, GroupSize>
>;
// clang-format on
// Test suite for AQuant Mem Decode Intrawave
TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTest)
{
this->run_test_with_validation(16, 64, 512);
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for AQuant tests - Mem Prefill Interwave Configuration
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using AQuantMemPrefillInterwaveTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefillInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefillInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefillInterwave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefillInterwave, GroupSize>
>;
// clang-format on
// Test suite for AQuant Mem Prefill Interwave
TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// clang-format off
using AQuantPrefillTypes = ::testing::Types<
// RCR layout - with the Prefill BlockTile Config.
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefillIntrawave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefillIntrawave, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefillIntrawave, GroupSize>
>;
// clang-format on

View File

@@ -69,6 +69,38 @@ struct GemmConfigPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
};
// Prefill-style block tile (128x128x128) driven by the Intrawave GEMM
// pipeline scheduler; remaining tuning parameters come from GemmConfigBase.
struct GemmConfigPrefillIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128; // block tile extent along M
static constexpr ck_tile::index_t N_Tile = 128; // block tile extent along N
static constexpr ck_tile::index_t K_Tile = 128; // block tile extent along K
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Prefill-style block tile (128x128x128) driven by the Interwave GEMM
// pipeline scheduler; remaining tuning parameters come from GemmConfigBase.
struct GemmConfigPrefillInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128; // block tile extent along M
static constexpr ck_tile::index_t N_Tile = 128; // block tile extent along N
static constexpr ck_tile::index_t K_Tile = 128; // block tile extent along K
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
// Decode-style block tile (16x64x256): small M for skinny GEMMs, deep K tile
// (spans two 128-element quant groups), Intrawave scheduler. Remaining tuning
// parameters come from GemmConfigBase.
struct GemmConfigDecodeIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;  // block tile extent along M
static constexpr ck_tile::index_t N_Tile = 64;  // block tile extent along N
static constexpr ck_tile::index_t K_Tile = 256; // block tile extent along K
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Decode-style block tile (16x64x256): small M for skinny GEMMs, deep K tile
// (spans two 128-element quant groups), Interwave scheduler. Remaining tuning
// parameters come from GemmConfigBase.
struct GemmConfigDecodeInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;  // block tile extent along M
static constexpr ck_tile::index_t N_Tile = 64;  // block tile extent along N
static constexpr ck_tile::index_t K_Tile = 256; // block tile extent along K
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigMxFp4 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
@@ -374,6 +406,223 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
}
};
// Test fixture for the AQuant (grouped activation-quantization) GEMM memory
// pipeline. The CRTP base (TestCkTileGemmQuantBase) extracts the type
// parameters from the test tuple and routes invoke_quant_gemm() into
// run_quant_gemm_impl() below. Each test generates random A/B/scale tensors,
// runs the device kernel, and validates against the host reference GEMM.
//
// Fix vs. previous revision: removed a leftover commented-out debug upload of
// the unshuffled scale tensor (superseded by the PreshuffleQuant branch).
template <typename Tuple>
class TestCkTileGemmAQuantMem
    : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuantMem<Tuple>>
{
    using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuantMem<Tuple>>;
    friend Base;

    public:
    using typename Base::AccDataType;
    using typename Base::ADataType;
    using typename Base::ALayout;
    using typename Base::AQLayout;
    using typename Base::BDataType;
    using typename Base::BLayout;
    using typename Base::CDataType;
    using typename Base::CLayout;
    using typename Base::ComputeDataType;
    using typename Base::QDataType;
    using typename Base::QuantGroupSize;

    static constexpr auto QuantType = Base::QuantType;

    protected:
    // No AQuant-specific per-test setup/teardown is needed.
    void SetUpQuantTypeSpecific() {}
    void TearDownQuantTypeSpecific() {}

    // AQuant-specific data generation.
    // Runs one M x N x K GEMM on device and validates it against the host
    // reference. A carries one scale per QuantGroupSize::kK elements along K;
    // the scales live in aq_m_aqk (M x ceil(K / kK)).
    void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
    {
        const ck_tile::index_t stride_A =
            ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
        const ck_tile::index_t stride_B =
            ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
        const ck_tile::index_t stride_C =
            ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));
        // AQuant uses grouped quantization for A matrix
        const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
        // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant)
        const ck_tile::index_t stride_AQ =
            ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{}));
        // Generate test data
        ck_tile::HostTensor<ADataType> a_m_k(
            ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
        // AQLayout is independently specified for each test case
        ck_tile::HostTensor<QDataType> aq_m_aqk(
            ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{})));
        ck_tile::HostTensor<BDataType> b_k_n(
            ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
        // Initialize data with random values (pk_int4 gets a wider range to
        // exercise its full representable span)
        if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
        {
            ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
        }
        else
        {
            ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
        }
        ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
        ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
        // Allocate device memory
        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
        ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType));
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
        ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
        // Copy to device
        if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
        {
            // Permute vector pk_i4x4 data for device implementation
            ck_tile::HostTensor<ADataType> temp = a_m_k;
            ck_tile::permute_vectors_i4x4_b(temp);
            a_m_k_dev_buf.ToDevice(temp.data());
        }
        else
        {
            a_m_k_dev_buf.ToDevice(a_m_k.data());
        }
        if constexpr(Base::GemmConfig::PreshuffleQuant)
        {
            // The kernel expects scales pre-shuffled to match its K-tile
            // traversal order; shuffle on host before upload.
            ck_tile::HostTensor<QDataType> aq_shuffle_host =
                ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
            aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
        }
        else
        {
            aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
        }
        b_k_n_dev_buf.ToDevice(b_k_n.data());
        // Create args for kernel execution
        ck_tile::QuantGemmHostArgs args{
            a_m_k_dev_buf.GetDeviceBuffer(),   // a_ptr
            b_k_n_dev_buf.GetDeviceBuffer(),   // b_ptr
            c_m_n_dev_buf.GetDeviceBuffer(),   // c_ptr
            aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
            nullptr,                           // bq_ptr (not used for AQuant)
            1,                                 // k_batch
            M,
            N,
            K,   // M, N, K
            AQK, // QK_A
            0,   // QK_B (not used for AQuant)
            stride_A,
            stride_B,
            stride_C,
            stride_AQ,
            0 // stride_BQ (unused for AQuant)
        };
        // Run the kernel
        ck_tile::stream_config stream_config{};
        this->invoke_quant_gemm(args, stream_config);
        // Validation using reference implementation
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
        c_m_n_host_ref.SetZero();
        // Run reference AQuant implementation
        ck_tile::reference_gemm_quant<ADataType,
                                      QDataType,
                                      BDataType,
                                      AccDataType,
                                      CDataType,
                                      QuantGroupSize,
                                      true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
        // Get device result
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
        // Calculate error tolerances scaled by the largest reference output
        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol =
            this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
                K, 1, max_accumulated_value);
        // Validate results
        bool pass = ck_tile::check_err(c_m_n_dev_result,
                                       c_m_n_host_ref,
                                       "Error: Incorrect results!",
                                       rtol_atol.at(ck_tile::number<0>{}),
                                       rtol_atol.at(ck_tile::number<1>{}));
        EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N
                          << ", K=" << K;
        if(!pass)
        {
            std::cout << "AQuantGrouped - Relative error threshold: "
                      << rtol_atol.at(ck_tile::number<0>{})
                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                      << std::endl;
        }
    }

    private:
    // AQuant-specific pipeline implementation: instantiates the memory
    // ("Mem") pipeline kernel for the concrete hot-loop/tail-number variant
    // selected at runtime from K, then launches it.
    template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
    void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
                             const ck_tile::stream_config& s)
    {
        using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
                                                                    BDataType,
                                                                    AccDataType,
                                                                    CodegenGemmShape,
                                                                    CodegenGemmTraits,
                                                                    ComputeDataType>;
        using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>;
        // Round K up to a multiple of the K tile to derive the loop count.
        const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
        // Invoked by TailHandler with compile-time hot-loop/tail constants.
        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
            using PipelineProblem = ck_tile::GemmAQuantPipelineProblem<ADataType,
                                                                      QDataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenGemmShape,
                                                                      CodegenGemmTraits,
                                                                      QuantGroupSize,
                                                                      transpose_c,
                                                                      ComputeDataType,
                                                                      Base::GemmConfig::Scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>;
            using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>;
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 ck_tile::tuple<>,
                                                 AccDataType,
                                                 CDataType,
                                                 ck_tile::tuple<>,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 Base::M_Warp,
                                                 Base::N_Warp,
                                                 Base::M_Warp_Tile,
                                                 Base::N_Warp_Tile,
                                                 Base::K_Warp_Tile,
                                                 transpose_c>>;
            using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
                                                    GemmPipeline,
                                                    GemmEpilogue,
                                                    ck_tile::QuantType::AQuantGrouped>;
            auto kargs = Kernel::MakeKernelArgs(args);
            const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = Kernel::BlockSize();
            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Arguments not supported for AQuant kernel");
            }
            ck_tile::launch_kernel(s,
                                   ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
                                       Kernel{}, grids, blocks, 0, kargs));
        };
        return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
    }
};
// BQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>