mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Support A/B Quantization in Blockscale GEMM (#3343)
* Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Implement review suggested changes * Sync with develop * fix pre-commit error * Add unit tests for blockscale AB-Quantization * fix pre-commit error * fix pre-commit error * fix compile error * fix compile error * fix clang-format * fix clang-format * fix enumeration values not handled in switch * rebase file * Add missing enums to data_type_sizeof (#3430) Fixes broken build on gfx942. This was some test code that got merged at the same time. * [CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419) * Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen * Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts --------- Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> * Add build trace diagnostics to CI. 
(#3432) * generate and visualize build traces for all archs * generate build traces in all cases * fix jenkins logic * fix typo * use more threads for parsing dependency map * add script to parse ninja traces and issue warnings * fix python script syntax and header * fix python syntax one more time * fix python syntax * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Sync with develop * Add unit tests for blockscale AB-Quantization * fix enumeration values not handled in switch * rebase file * rebase file --------- Co-authored-by: John Shumway <jshumway@amd.com> Co-authored-by: DarylHawkinsAMD <Daryl.Hawkins@amd.com> Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
6
test/ck_tile/gemm_block_scale/CMakeLists.txt
Executable file → Normal file
6
test/ck_tile/gemm_block_scale/CMakeLists.txt
Executable file → Normal file
@@ -25,6 +25,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr
|
||||
test_gemm_quant_aquant_base_ccr.cpp
|
||||
)
|
||||
# ABQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant
|
||||
test_gemm_quant_abquant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
|
||||
55
test/ck_tile/gemm_block_scale/test_gemm_quant_abquant.cpp
Normal file
55
test/ck_tile/gemm_block_scale/test_gemm_quant_abquant.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false, RowMajor AQ and ColumnMajor BQ throughout.
// A/B/C layouts cover RCR, CRR and RRR, each once for FP8 and once for BF8 inputs.
// First six entries: GroupSize is used as the quant group for both A and B.
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
|
||||
// Last six entries: the same layout/type combinations, but B uses the GroupSize2D128N quant group.
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
|
||||
|
||||
// ABQuant tests: instantiated for every tuple registered with the TestCkTileGemmABQuant suite;
// each instance runs one M = N = K = 1024 problem through run_test_with_validation().
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -21,6 +21,24 @@
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits;
|
||||
|
||||
/// Tuple element lookup with a fallback type.
///
/// Primary template: chosen whenever the partial specialization below is disabled, i.e. when
/// `Index` is not a valid index into `TTuple`. In that case `type` is `DefaultType`.
///
/// @tparam TTuple      Tuple-like type; `std::tuple_size_v<TTuple>` must be well-formed.
/// @tparam Index       Zero-based position to look up.
/// @tparam DefaultType Type reported when `TTuple` has no element at `Index`.
/// @tparam Enable      SFINAE hook; callers leave it at its default.
template <typename TTuple, std::size_t Index, typename DefaultType, typename Enable = void>
struct SafeTupleElement
{
    using type = DefaultType;
};

/// In-range case: `Index < std::tuple_size_v<TTuple>`, so `type` is the real tuple element and
/// `DefaultType` is ignored.
template <typename TTuple, std::size_t Index, typename DefaultType>
struct SafeTupleElement<TTuple,
                        Index,
                        DefaultType,
                        std::enable_if_t<(Index < std::tuple_size_v<TTuple>)>>
{
    using type = std::tuple_element_t<Index, TTuple>;
};

/// Convenience alias for `SafeTupleElement<TTuple, Index, DefaultType>::type`.
template <typename TTuple, std::size_t Index, typename DefaultType>
using SafeTupleElement_t = typename SafeTupleElement<TTuple, Index, DefaultType>::type;
|
||||
|
||||
// Base class for common quant gemm functionality
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
@@ -37,6 +55,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
static constexpr auto QuantType = std::tuple_element_t<8, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<9, Tuple>;
|
||||
using QuantGroupSize = std::tuple_element_t<10, Tuple>;
|
||||
using AQuantGroupSize = QuantGroupSize;
|
||||
using BQuantGroupSize = SafeTupleElement_t<Tuple, 11, QuantGroupSize>;
|
||||
using BQLayout = SafeTupleElement_t<Tuple, 12, AQLayout>;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
@@ -86,9 +107,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
// Re-use the AQLayout for BQLayout
|
||||
using BQLayout = AQLayout;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
@@ -155,7 +173,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits
|
||||
{
|
||||
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
|
||||
static_assert(QT == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QT == ck_tile::QuantType::AQuantGrouped ||
|
||||
QT == ck_tile::QuantType::BQuantGrouped ||
|
||||
QT == ck_tile::QuantType::RowColQuant ||
|
||||
QT == ck_tile::QuantType::TensorQuant,
|
||||
@@ -182,6 +201,16 @@ struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
|
||||
static constexpr const char* name = "bquant";
|
||||
};
|
||||
|
||||
// Specialization for ABQuantGrouped
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::ABQuantGrouped>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = BDataType; // For ABQuant, compute type is BDataType (ADataType is unused)
|
||||
|
||||
// Short identifier for this quant type.
static constexpr const char* name = "abquant";
|
||||
};
|
||||
|
||||
// Specialization for RowColQuant
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>
|
||||
|
||||
@@ -664,6 +664,314 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
}
|
||||
};
|
||||
|
||||
// ABQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmABQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmABQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::AQLayout;
|
||||
using typename Base::AQuantGroupSize;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::BQuantGroupSize;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::GemmConfig;
|
||||
using typename Base::QDataType;
|
||||
using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, AQuantGroupSize::kK);
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, BQuantGroupSize::kK);
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{}));
|
||||
const ck_tile::index_t stride_BQ =
|
||||
ck_tile::get_default_stride(BQK, BQN, 0, this->is_row_major(BQLayout{}));
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
// AQLayout is independently specified for each test case
|
||||
ck_tile::HostTensor<QDataType> aq_m_aqk( // AQDataType
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(bq_bqk_bqn);
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() *
|
||||
sizeof(QDataType)); // AQDataType
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem bq_bqk_bqn_dev_buf(bq_bqk_bqn.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> temp = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
a_m_k_dev_buf.ToDevice(temp.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
if constexpr(TiledMMAPermuteN && BQuantGroupSize::kN == 1)
|
||||
{
|
||||
printf("PreshuffleB with TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("PreshuffleB without TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
}
|
||||
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
}
|
||||
if constexpr(PreshuffleB && TiledMMAPermuteN && BQuantGroupSize::kN == 1)
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, BQuantGroupSize::kN);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
|
||||
}
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
AQK, // QK_A
|
||||
BQK, // QK_B
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_AQ,
|
||||
stride_BQ // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference ABQuant implementation
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
QDataType, // AQDataType
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k, aq_m_aqk, b_k_n, bq_bqk_bqn, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "ABQuantGrouped - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// ABQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
QDataType, // AQDataType
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledMMAPermuteN>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::ABQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for ABQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmPreshuffleBBQuant : public TestCkTileGemmBQuant<Tuple>
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user