mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] B matrix 2D block scale gemm (#3074)
* Refactor quant group size to be configurable for M/N/K, not just K * add some asserts for configurations not implemented * start setting of group size for N dimension * enable 2d for reference quant gemm * WIP: trying to figure out tile dstr and/or indexing for scale matrix * WIP * Fix handling of n dim blocks in tile windows etc * remove commented code and enable all tests again * fix formatting * Add more specialized tile distributions * Enable NWarps replication for bquant tile dstr * fix formatting * fix format * Fix some issues from the merge * fix formatting * one more fix to tile dstr, and revert debug initialization * Remove commented code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify conditions that are needed for tile distributions * only enable the working group sizes in tests * fix formatting * Update tile distribution for 2D bquant * add some documentation and 2d block scale example * fix formatting * Add in Changlog and restructure the quant 2d example * fix CMake * support the change for blockscale 2d * fix the test file --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -26,17 +26,17 @@ template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
using QuantGroupSize = std::tuple_element_t<9, Tuple>;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
using QuantTraits = QuantTypeTraits<QuantType>;
|
||||
|
||||
@@ -31,7 +31,7 @@ struct GemmConfigBase
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
@@ -119,9 +119,9 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -135,7 +135,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
|
||||
|
||||
@@ -181,7 +181,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -359,11 +359,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::GemmConfig;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -375,8 +375,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// BQuant uses grouped quantization for B matrix
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
const ck_tile::index_t stride_BQ = BQK;
|
||||
|
||||
// Generate test data
|
||||
@@ -384,18 +385,18 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_n(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
|
||||
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem bq_bqk_bqn_dev_buf(bq_bqk_bqn.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
@@ -425,25 +426,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_n);
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
{
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
|
||||
}
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B
|
||||
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
@@ -467,7 +470,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
@@ -614,9 +617,9 @@ class TestCkTileGemmRowColQuant
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -831,9 +834,9 @@ class TestCkTileGemmTensorQuant
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
|
||||
@@ -20,7 +20,15 @@ using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantT
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = std::integral_constant<unsigned int, 128>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
@@ -53,10 +61,38 @@ using AQuantTypes = ::testing::Types<
|
||||
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
|
||||
// 2d cases with grouping also on the n axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -77,6 +113,7 @@ using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
|
||||
Reference in New Issue
Block a user