mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Row/Col quant gemm (#2729)
* Add cshuffle epilogue test
* add the poc implementation to the epilogue and tests
* refactor cshuffle epilogue
* WIP: adding tensor/tile usage to scale_tile
* fix usage of tile_elementwise_inout
* add gemm_quant_kernel for generalizing gemm quant kernel
* Add problem specific to different quants, add QuantType to Traits
* Add quant_type to quant_kernel template parameters
* Create aq/bq_block_windows and views depending on QuantType
* Use tile windows as inputs in cshuffle epilogue
* Fix some issues in epilogue
* initial new example code for new general gemm quant kernel test
* Fix issues in kernel
* Add verification check for rowcol Quantmode
* use AccDataType instead of AQ in pipeline
* fix aquant preshuffle
* fix formatting
* some cleanup
* remove gemm_aquant_basic.cpp
* remove gemm_aquant_kernel.hpp
* fix tests for the renamed quant kernel
* fix formatting
* clean example files
* fix some merge conflicts
* fix preshufflequant rename issue
* fix some templates after merging with develop
* fix test preshuffle parameter
* fix formatting
* Unify bquant kernel to the common quant kernel
* remove bquant kernel also from common header
* fix formatting
* clean up commented code
* fix formatting config hpp
* fix merge mistake
* Non-const for movable windows
* fix formatting
* Fix grammar in README
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* Remove #include<bit> and clean up example
* fix strides
* Add some descriptions for move_windows
---------
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
@@ -14,6 +14,7 @@ namespace ck_tile {
|
||||
template <typename ADataType_,
|
||||
typename AQDataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
@@ -23,12 +24,12 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
@@ -44,6 +45,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
using typename Base::CDataType;
|
||||
using typename Base::ComputeDataType;
|
||||
using AQDataType = remove_cvref_t<AQDataType_>;
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
@@ -63,6 +65,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
using Base::VectorLoadSize;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
@@ -75,7 +78,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_aquant_problem",
|
||||
return concat('_', "gemm_quant_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
@@ -94,6 +97,13 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
return kPadK ? 1 : GetAlignmentAQ();
|
||||
}();
|
||||
|
||||
// Returns VectorLoadSize divided by sizeof(BQDataType), i.e. how many
// BQDataType elements fit into one VectorLoadSize-sized unit (presumably a
// byte count — confirm against the base problem's VectorLoadSize definition).
// The result feeds VectorSizeBQ, which falls back to 1 when kPadK is set.
// NOTE(review): the GemmAQuantPipelineProblem alias passes `void` as
// BQDataType_, and sizeof(void) is ill-formed in standard C++. This presumably
// relies on this member never being instantiated for the AQuant configuration
// — confirm, or guard the void case explicitly.
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
|
||||
{
|
||||
return VectorLoadSize / sizeof(BQDataType);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -108,18 +118,19 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
|
||||
AQDataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AQDataType_,
|
||||
BDataType_,
|
||||
void, // no BQDataType for AQuant
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -132,96 +143,42 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using typename Base::ADataType;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::ComputeDataType;
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CLayout;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
using Base::kBlockSize;
|
||||
|
||||
using Base::kPadK;
|
||||
using Base::kPadM;
|
||||
using Base::kPadN;
|
||||
|
||||
using Base::DoubleSmemBuffer;
|
||||
using Base::VectorLoadSize;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
|
||||
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
|
||||
|
||||
// Host-only helper that returns a descriptive name string for this problem
// configuration. The name is assembled by concat with '_' as its first
// argument from: the literal "gemm_bquant_problem", an inner concat('x', ...)
// of VectorLoadSize and kBlockSize, an inner concat('x', ...) of the three
// padding flags, the Scheduler value, the literal "QuantGroupSize" and
// kQuantGroupSize.
// NOTE(review): concat is defined outside this view; presumably its first
// argument is the separator placed between the remaining arguments — confirm.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_bquant_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"QuantGroupSize",
|
||||
kQuantGroupSize);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Returns VectorLoadSize divided by sizeof(BQDataType), i.e. how many
// BQDataType elements fit into one VectorLoadSize-sized unit (presumably a
// byte count — confirm against the base problem's VectorLoadSize definition).
// Used by VectorSizeBQ below, which substitutes 1 when kPadK is set.
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
|
||||
{
|
||||
return VectorLoadSize / sizeof(BQDataType);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
|
||||
};
|
||||
using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
void, // no AQDataType for BQuant
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
false, // no TransposeC
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool TransposeC_ = false,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,9 +4,17 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Scoped enumeration (underlying type std::uint16_t) naming the quantization
// variant of a quant GEMM configuration. It is taken as the QuantType_
// non-type template parameter of TileGemmQuantTraits and re-exported there as
// kQuantType.
// NOTE(review): judging by the enumerator names and the pipeline-problem
// aliases in this change, AQuantGrouped/BQuantGrouped presumably denote grouped
// scales on the A or B operand respectively, and RowColQuant scales on both
// operands without a quant group size — confirm against the kernel code.
enum struct QuantType : std::uint16_t
|
||||
{
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
@@ -14,19 +22,24 @@ template <bool kPadM_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename AQLayout_ = ALayout_>
|
||||
struct TileGemmAQuantTraits
|
||||
QuantType QuantType_,
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
struct TileGemmQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr QuantType kQuantType = QuantType_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AQLayout = AQLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
@@ -35,31 +48,4 @@ struct TileGemmAQuantTraits
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
// Compile-time traits bundle for a B-quant GEMM tile configuration.
// Template parameters:
//   kPadM_, kPadN_, kPadK_  - padding flags, re-exported as kPadM/kPadN/kPadK.
//   PreshuffleQuant_        - re-exported as PreshuffleQuant.
//   ALayout_, BLayout_, CLayout_ - layout tag types, re-exported as
//                             ALayout/BLayout/CLayout.
//   BQLayout_               - layout tag for the BQ tensor; defaults to BLayout_.
// The struct also fixes _VectorSize to 16, TransposeC and
// UseStructuredSparsity to false, and NumWaveGroups to 1.
// NOTE(review): the enclosing hunk header (-35,31 +48,4) suggests this struct
// is on the removed side of the change, superseded by TileGemmQuantTraits —
// confirm before relying on it.
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool PreshuffleQuant_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
struct TileGemmBQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user