mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Support A/B Quantization in Blockscale GEMM (#3343)
* Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Implement review suggested changes * Sync with develop * fix pre-commit error * Add unit tests for blockscale AB-Quantization * fix pre-commit error * fix pre-commit error * fix compile error * fix compile error * fix clang-format * fix clang-format * fix enumeration values not handled in switch * rebase file * Add missing enums to data_type_sizeof (#3430) Fixes broken build on gfx942. This was some test code that got merged at the same time. * [CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419) * Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen * Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts --------- Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> * Add build trace diagnostics to CI. (#3432) * generate and visualize build traces for all archs * generate build traces in all cases * fix jenkins logic * fix typo * use more threads for parsing dependency map * add script to parse ninja traces and issue warnings * fix python script syntax and header * fix python syntax one more time * fix python syntax * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Sync with develop * Add unit tests for blockscale AB-Quantization * fix enumeration values not handled in switch * rebase file * rebase file --------- Co-authored-by: John Shumway <jshumway@amd.com> Co-authored-by: DarylHawkinsAMD <Daryl.Hawkins@amd.com> Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmABQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using AQuantBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using BQuantBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
template <typename AQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
|
||||
{
|
||||
return AQuantBase{}.GetAQDramLoadWindow(aq_dram_block_window_tmp);
|
||||
}
|
||||
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
|
||||
{
|
||||
return BQuantBase{}.GetBQDramLoadWindow(bq_dram_block_window_tmp);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "gemm_group_quant_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using Base = UniversalGemmPipelineAgBgCrPolicy;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
|
||||
{
|
||||
return GemmAQuantPipelineAgBgCrDefaultPolicy::GetVectorSizeAQ<Problem>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
|
||||
{
|
||||
return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution<Problem>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
|
||||
{
|
||||
return GemmBQuantPipelineAgBgCrDefaultPolicy::GetVectorSizeBQ<Problem>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
|
||||
{
|
||||
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<typename Problem::CDataType, float>);
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return ABQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,604 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem, typename Policy = GemmABQuantPipelineAgBgCrDefaultPolicy>
|
||||
struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kM == 1, "no block M for AQuant kernel supported yet!");
|
||||
static_assert(AQuantGroupSize::kK == BQuantGroupSize::kK,
|
||||
"AQuantGroupSize::kK should be equal to BQuantGroupSize::kK");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t AQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<AQDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeAQ()
|
||||
{
|
||||
return Policy::template GetVectorSizeAQ<Problem>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
return Policy::template GetVectorSizeBQ<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
return concat('_', "abquant_pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock),
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
constexpr index_t AQ_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ());
|
||||
constexpr index_t BQ_Buffer_Load_Inst_Num =
|
||||
NPerBlockBQ * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
auto str = std::stringstream{};
|
||||
|
||||
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
|
||||
<< "AQ vector size: " << GetVectorSizeAQ() << "\n"
|
||||
<< "BQ vector size: " << GetVectorSizeBQ() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
|
||||
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename ADramWindow, typename ABlockTile_>
|
||||
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
|
||||
const ADramWindow& a_dram_window)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
}
|
||||
|
||||
template <typename BDramWindow, typename BBlockTile_>
|
||||
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
|
||||
const BDramWindow& b_dram_window)
|
||||
{
|
||||
using DestDataType = typename BBlockTile_::DataType;
|
||||
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t m,
|
||||
index_t n,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<AQDataType,
|
||||
remove_cvref_t<typename AQDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BQDataType,
|
||||
remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
|
||||
"A/B/AQ/BQ Dram block window should have the same data type as appropriate "
|
||||
"([A|B|AQ|BQ]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_aq_col_major =
|
||||
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
constexpr bool is_bq_row_major =
|
||||
std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
static_assert(
|
||||
PreshuffleQuant ||
|
||||
(is_bq_row_major
|
||||
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])),
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
|
||||
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
|
||||
auto&& [a_lds_block, b_lds_block] =
|
||||
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
|
||||
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
|
||||
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
using AQBlockTile =
|
||||
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
|
||||
using BQBlockTile =
|
||||
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
|
||||
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
AQBlockTile aq_block_tile[2];
|
||||
BQBlockTile bq_block_tile[2];
|
||||
int currIdx = 0;
|
||||
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
constexpr index_t tail_count =
|
||||
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
// Note: BDataType PkInt4 gets converted during loading earlier
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window,
|
||||
// a_dram_tile_window_step);
|
||||
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window,
|
||||
// b_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
|
||||
aq_copy_dram_window,
|
||||
aq_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tile[currIdx],
|
||||
bq_block_tile[currIdx],
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - tail_count));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tile[currIdx],
|
||||
bq_block_tile[currIdx],
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
|
||||
aq_copy_dram_window,
|
||||
aq_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tile[currIdx],
|
||||
bq_block_tile[currIdx],
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
// Note: BDataType gets converted during loading from PkInt4
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tile[currIdx],
|
||||
bq_block_tile[currIdx],
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
// Overload for PreshuffleQuant = true
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem,
|
||||
index_t m = 0,
|
||||
index_t n = 0) const
|
||||
{
|
||||
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
m,
|
||||
n,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
|
||||
///
|
||||
/// This operator is used by grouped GEMM kernels where pipeline parameters
|
||||
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
|
||||
/// at runtime, not on the host side during compilation. This is necessary
|
||||
/// because different GEMM problems in the group may have different K dimensions,
|
||||
/// requiring different pipeline configurations that cannot be determined at
|
||||
/// compile time.
|
||||
///
|
||||
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
|
||||
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
|
||||
/// @param aq_dram_block_window_tmp Block window for AQ (quantization scale) tensor in DRAM
|
||||
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
|
||||
/// @param num_loop Number of main loop iterations (calculated on device)
|
||||
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
|
||||
/// @param tail_number Type of tail handling required (calculated on device)
|
||||
/// @param p_smem Pointer to shared memory
|
||||
/// @return Accumulated result tile in registers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem,
|
||||
index_t m = 0,
|
||||
index_t n = 0) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
|
||||
constexpr bool hot_loop = has_hot_loop_.value;
|
||||
constexpr auto tail_num = tail_number_.value;
|
||||
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
// Note: BDataType PkInt4 gets converted during loading
|
||||
[](const OverrideBDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
m,
|
||||
n, // dummy value, won't be used
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -18,7 +18,7 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -28,7 +28,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
@@ -21,7 +21,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
}
|
||||
@@ -35,7 +35,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -116,7 +116,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
static_assert(Problem::AQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
|
||||
@@ -25,7 +25,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
@@ -18,7 +18,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
|
||||
// Support both RowMajor and ColumnMajor layouts for BQ
|
||||
if constexpr(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -45,9 +45,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
@@ -82,7 +82,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
BlockSize,
|
||||
KPerBlockBQ, // Logical K dimension
|
||||
NPerBlockBQ, // Logical N dimension
|
||||
Problem::QuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kN,
|
||||
BQLayout>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
@@ -95,7 +95,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
|
||||
@@ -31,7 +31,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -18,7 +18,8 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename QuantGroupSize_,
|
||||
typename AQuantGroupSize_,
|
||||
typename BQuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -48,7 +49,12 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = QuantGroupSize_;
|
||||
using AQuantGroupSize =
|
||||
std::conditional_t<!std::is_void_v<AQuantGroupSize_>, AQuantGroupSize_, BQuantGroupSize_>;
|
||||
using BQuantGroupSize =
|
||||
std::conditional_t<!std::is_void_v<BQuantGroupSize_>, BQuantGroupSize_, AQuantGroupSize_>;
|
||||
// Unified alias for 1D quantization usage, to avoid forcing users to pick one.
|
||||
using QuantGroupSize = BQuantGroupSize;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
@@ -72,9 +78,12 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
|
||||
static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % AQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
|
||||
static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % BQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -83,7 +92,8 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
QuantGroupSize::GetName());
|
||||
AQuantGroupSize::GetName(),
|
||||
BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -126,6 +136,7 @@ using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
void,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
@@ -150,6 +161,7 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
void,
|
||||
QuantGroupSize_,
|
||||
false, // no TransposeC
|
||||
ComputeDataType_,
|
||||
@@ -157,6 +169,35 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename AQDataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename AQuantGroupSize_,
|
||||
typename BQuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmABQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AQDataType_,
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
AQuantGroupSize_,
|
||||
BQuantGroupSize_,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
@@ -176,6 +217,7 @@ using GemmRowColTensorQuantPipelineProblem =
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
void,
|
||||
QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
|
||||
@@ -16,9 +16,9 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
@@ -10,10 +10,11 @@ namespace ck_tile {
|
||||
|
||||
enum struct QuantType : std::uint16_t
|
||||
{
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2,
|
||||
TensorQuant = 3
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2,
|
||||
TensorQuant = 3,
|
||||
ABQuantGrouped = 4
|
||||
};
|
||||
|
||||
inline std::string quant_type_to_string(QuantType quant_type)
|
||||
@@ -24,6 +25,7 @@ inline std::string quant_type_to_string(QuantType quant_type)
|
||||
case QuantType::BQuantGrouped: return "BQuantGrouped";
|
||||
case QuantType::RowColQuant: return "RowColQuant";
|
||||
case QuantType::TensorQuant: return "TensorQuant";
|
||||
case QuantType::ABQuantGrouped: return "ABQuantGrouped";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user