[CK_TILE] Row/Col quant gemm (#2729)

* Add cshuffle epilogue test

* add the poc implementation to the epilogue and tests

* refactor cshuffle epilogue

* WIP: adding tensor/tile usage to scale_tile

* fix usage of tile_elementwise_inout

* add gemm_quant_kernel for generalizing gemm quant kernel

* Add problem specific to different quants, add QuantType to Traits

* Add quant_type to quant_kernel template parameters

* Create aq/bq_block_windows and views depending on QuantType

* Use tile windows as inputs in cshuffle epilogue

* Fix some issues in epilogue

* initial new example code for new general gemm quant kernel test

* Fix issues in kernel

* Add verification check for rowcol Quantmode

* use AccDataType instead of AQ in pipeline

* fix aquant preshuffle

* fix formatting

* some cleanup

* remove gemm_aquant_basic.cpp

* remove gemm_aquant_kernel.hpp

* fix tests for the renamed quant kernel

* fix formatting

* clean example files

* fix some merge conflicts

* fix preshufflequant rename issue

* fix some templates after merging with develop

* fix test preshuffle parameter

* fix formatting

* Unify bquant kernel to the common quant kernel

* remove bquant kernel also from common header

* fix formatting

* clean up commented code

* fix formatting config hpp

* fix merge mistake

* Non-const for movable windows

* fix formatting

* Fix grammar in README

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Remove #include<bit> and clean up example

* fix strides

* Add some descriptions for move_windows

---------

Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
This commit is contained in:
Sami Remes
2025-09-05 02:17:12 +03:00
committed by GitHub
parent 7330ec37ee
commit c6010f2953
23 changed files with 1837 additions and 1331 deletions

View File

@@ -14,6 +14,7 @@ namespace ck_tile {
template <typename ADataType_,
typename AQDataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
@@ -23,12 +24,12 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
@@ -44,6 +45,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using typename Base::CDataType;
using typename Base::ComputeDataType;
using AQDataType = remove_cvref_t<AQDataType_>;
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
@@ -63,6 +65,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using Base::VectorLoadSize;
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
@@ -75,7 +78,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_aquant_problem",
return concat('_', "gemm_quant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
@@ -94,6 +97,13 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return kPadK ? 1 : GetAlignmentAQ();
}();
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
{
return VectorLoadSize / sizeof(BQDataType);
}
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
};
template <typename ADataType_,
@@ -108,18 +118,19 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
AQDataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AQDataType_,
BDataType_,
void, // no BQDataType for AQuant
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
template <typename ADataType_,
typename BDataType_,
@@ -132,96 +143,42 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Traits = typename Base::Traits;
using typename Base::ADataType;
using typename Base::BDataType;
using typename Base::CDataType;
using typename Base::ComputeDataType;
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
using typename Base::ALayout;
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = Traits::TransposeC;
using Base::kBlockSize;
using Base::kPadK;
using Base::kPadM;
using Base::kPadN;
using Base::DoubleSmemBuffer;
using Base::VectorLoadSize;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_bquant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"QuantGroupSize",
kQuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
{
return VectorLoadSize / sizeof(BQDataType);
}
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
};
using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
void, // no AQDataType for BQuant
BDataType_,
BQDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
false, // no TransposeC
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
template <typename ADataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename AccDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = ADataType_,
bool TransposeC_ = false,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
BDataType_,
BQDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile

View File

@@ -4,9 +4,17 @@
#pragma once
#include "ck_tile/core.hpp"
#include <cstdint>
namespace ck_tile {
enum struct QuantType : std::uint16_t
{
AQuantGrouped = 0,
BQuantGrouped = 1,
RowColQuant = 2
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
@@ -14,19 +22,24 @@ template <bool kPadM_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename AQLayout_ = ALayout_>
struct TileGemmAQuantTraits
QuantType QuantType_,
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using AQLayout = AQLayout_;
using BQLayout = BQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
@@ -35,31 +48,4 @@ struct TileGemmAQuantTraits
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename BQLayout_ = BLayout_>
struct TileGemmBQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using BQLayout = BQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
} // namespace ck_tile