mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] B matrix 2D block scale gemm (#3074)
* Refactor quant group size to be configurable for M/N/K, not just K
* add some asserts for configurations not implemented
* start setting of group size for N dimension
* enable 2d for reference quant gemm
* WIP: trying to figure out tile dstr and/or indexing for scale matrix
* WIP
* Fix handling of n dim blocks in tile windows etc
* remove commented code and enable all tests again
* fix formatting
* Add more specialized tile distributions
* Enable NWarps replication for bquant tile dstr
* fix formatting
* fix format
* Fix some issues from the merge
* fix formatting
* one more fix to tile dstr, and revert debug initialization
* Remove commented code
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* simplify conditions that are needed for tile distributions
* only enable the working group sizes in tests
* fix formatting
* Update tile distribution for 2D bquant
* add some documentation and 2d block scale example
* fix formatting
* Add to Changelog and restructure the quant 2d example
* fix CMake
* support the change for blockscale 2d
* fix the test file
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
[ROCm/composable_kernel commit: 16e85cf179]
This commit is contained in:
@@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added WMMA (gfx12) support for FMHA.
|
||||
* Added pooling kernel in CK_TILE
|
||||
* Added top-k sigmoid kernel in CK_TILE
|
||||
* Added 2D block-scale support for CK_TILE GEMM.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid overly verbose dispatching.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
@@ -17,7 +22,7 @@ template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
@@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
// This example only supports BQuant (no AQuant)
|
||||
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
|
||||
// for aquant
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -266,6 +272,41 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Forward declaration for dispatch function
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[]);
|
||||
|
||||
// Helper function to parse a quantization group size string.
//
// Accepted forms:
//   "K"      -> {1, 1, K}   (a single number is taken as the K dimension)
//   "MxNxK"  -> {M, N, K}
//
// Throws std::runtime_error when the string has exactly one 'x' separator, or when a
// component is followed by characters that are not part of the number (e.g. "128abc" or
// "1x32x128x4"). A component that does not start with a number makes std::stoi throw
// std::invalid_argument; a value that does not fit in an int makes it throw
// std::out_of_range.
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
    // Converts one component and insists that std::stoi consumed all of it; otherwise
    // trailing garbage would be silently dropped and a malformed size accepted.
    const auto parse_component = [](const std::string& text) {
        size_t consumed = 0;
        const int value = std::stoi(text, &consumed);
        if(consumed != text.size())
        {
            throw std::runtime_error(
                "Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
        }
        return value;
    };

    const size_t first_x = group_size_str.find('x');
    if(first_x == std::string::npos)
    {
        // Single number provided, assume it's the K dimension
        return {1, 1, parse_component(group_size_str)};
    }

    const size_t second_x = group_size_str.find('x', first_x + 1);
    if(second_x == std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }

    const int m = parse_component(group_size_str.substr(0, first_x));
    const int n = parse_component(group_size_str.substr(first_x + 1, second_x - first_x - 1));
    const int k = parse_component(group_size_str.substr(second_x + 1));

    return {m, n, k};
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
@@ -273,139 +314,57 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
std::string group_size_str = arg_parser.get_str("group_size");
|
||||
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
|
||||
|
||||
// Dispatch based on group size (M, N, K)
|
||||
return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
|
||||
using QuantGroupSize = decltype(QGS_);
|
||||
return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
|
||||
data_type, quant_mode, a_layout, b_layout, argc, argv);
|
||||
});
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
// This example ONLY supports BQuant for 2D block scale quantization
|
||||
if(quant_mode != "bquant")
|
||||
{
|
||||
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
|
||||
}
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
@@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'bquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
@@ -435,19 +386,11 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'bquant'.");
|
||||
}
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -455,7 +398,27 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename> typename GemmConfig, typename F>
|
||||
int dispatch_group_size_ct(int m, int n, int k, F&& f)
|
||||
{
|
||||
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
|
||||
#define DISPATCH_ONE(M, N, K) \
|
||||
if(m == M && n == N && k == K) \
|
||||
{ \
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
|
||||
return f(QuantGroupSize{}); \
|
||||
}
|
||||
|
||||
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
|
||||
|
||||
#undef DISPATCH_ONE
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
// Use non-preshuffled GemmConfig for 2D block scale support
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,14 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
// List of supported quantization group sizes, written as (M, N, K) triples.
// Pass a three-argument macro as X; it is expanded once per entry, in the order below.
// dispatch_group_size_ct uses this list to match a runtime group size.
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
|
||||
X(1, 1, 64) /* 1D */ \
|
||||
X(1, 1, 128) /* 1D */ \
|
||||
X(1, 8, 128) /* 2D N=8 */ \
|
||||
X(1, 32, 128) /* 2D N=32 */ \
|
||||
X(1, 64, 128) /* 2D N=64 */ \
|
||||
X(1, 128, 128) /* 2D N=128 */
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -193,6 +201,22 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
// GEMM configuration parameterized on the element type PrecType.
// Only the tile, warp-count and warp-tile constants are set here; everything else is
// inherited from GemmConfigBase.
// NOTE(review): the M/N/K_Warp and *_Warp_Tile meanings below are inferred from the
// member names -- confirm against GemmConfigBase and its users.
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
// Block tile extents. K_Tile is the number of PrecType elements that fit in 128 bytes.
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
// Presumably the number of warps along M, N and K (1 x 4 x 1).
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
// Presumably the per-warp tile extents; the K extent is computed by get_k_warp_tile
// from PrecType and M_Warp_Tile.
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
@@ -288,7 +312,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -14,7 +14,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
@@ -113,7 +113,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -146,7 +146,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if(K % QuantGroupSize != 0)
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
|
||||
@@ -155,13 +155,13 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t AQK, BQK;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
@@ -357,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
uint32_t QuantGroupSize,
|
||||
typename QuantGroupSize,
|
||||
bool aquant,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
@@ -80,12 +80,11 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
v_block_acc += v_a * v_b;
|
||||
|
||||
// Apply group dequant scale
|
||||
if((k + 1) % QuantGroupSize == 0)
|
||||
if((k + 1) % QuantGroupSize::kK == 0)
|
||||
{
|
||||
float scale = 0.f;
|
||||
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
|
||||
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
|
||||
|
||||
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
|
||||
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
|
||||
if constexpr(std::is_same_v<QDataType, float>)
|
||||
{
|
||||
scale = q(outer_dim, inner_dim);
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on block distributed tensor.
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
@@ -24,6 +24,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -47,8 +51,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp =
|
||||
@@ -58,13 +61,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
@@ -46,7 +46,7 @@ struct BlockGemmAQuantBase
|
||||
|
||||
// A is block window on shared memory
|
||||
// AQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
@@ -66,16 +66,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -101,20 +101,20 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / kQuantGroupSize > 0,
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
|
||||
@@ -46,7 +46,7 @@ struct BlockGemmBQuantBase
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
@@ -66,16 +66,18 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -101,20 +103,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / kQuantGroupSize > 0,
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -340,23 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
}
|
||||
});
|
||||
|
||||
// Need to multiply bquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on bquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
|
||||
// Multiply bquant with accumulated C
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
|
||||
@@ -685,9 +685,10 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.QK_B, kargs.N),
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
@@ -831,10 +832,10 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
@@ -847,11 +848,12 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
@@ -907,11 +909,12 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -18,6 +18,7 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
|
||||
@@ -25,10 +26,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
|
||||
// Create DRAM tile window for AQ
|
||||
|
||||
@@ -86,6 +86,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -106,12 +110,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -147,7 +150,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize,
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
|
||||
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
|
||||
// clang-format on
|
||||
}
|
||||
@@ -204,7 +207,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
|
||||
@@ -22,7 +22,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
@@ -37,7 +37,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -99,8 +99,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
|
||||
@@ -91,6 +91,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -111,12 +115,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -152,7 +155,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -208,7 +211,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
|
||||
@@ -18,6 +18,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -25,11 +26,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize;
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
@@ -38,7 +44,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
using YPerTile = number<NPerBlock>;
|
||||
using YPerTile = number<NPerBlockBQ>;
|
||||
using XPerTile = number<KPerBlockBQ>;
|
||||
|
||||
auto bq_copy_dram_window =
|
||||
|
||||
@@ -21,11 +21,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -36,9 +37,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
@@ -49,12 +50,13 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlock,
|
||||
VecLoadSize>;
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
Problem::QuantGroupSize::kN>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -65,8 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
|
||||
@@ -91,7 +91,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
@@ -111,12 +113,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -151,7 +154,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -207,7 +210,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -255,7 +258,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
|
||||
@@ -171,11 +171,9 @@ template <typename BlockGemmShape,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
index_t XPerQ>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
|
||||
@@ -186,34 +184,94 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
|
||||
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps);
|
||||
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t Y = YPerTile;
|
||||
static constexpr index_t YR = 1;
|
||||
|
||||
// Number of iters per warp
|
||||
// MIters are indexed using (Y0, Y1)
|
||||
static constexpr index_t X0 = NIterPerWarp;
|
||||
|
||||
// # of warps in Y dim
|
||||
static constexpr index_t X1 = NWarps;
|
||||
|
||||
static constexpr index_t X2 = WarpGemm::kN;
|
||||
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along Y.");
|
||||
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and applying
|
||||
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
|
||||
/// to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
|
||||
/// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
|
||||
/// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans all warps along N (it is wider than WarpGemm::kN * NWarps)
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, YR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
if constexpr(XPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
constexpr index_t Y = YPerTile; // Full Y dimension of tile
|
||||
constexpr index_t YR = 1; // No Y replication needed
|
||||
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t X1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
|
||||
constexpr index_t XR = XPerQ; // Elements per quantization group
|
||||
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, YR, XR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
// Case 2: Medium-grained - one quantization scale per warp
|
||||
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else // XPerQ > WarpGemm::kN * NWarps
|
||||
{
|
||||
// Case 3: Coarse-grained - quantization group spans all warps
|
||||
// All warps in N-dimension share the same quantization scale
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GroupSizes>
|
||||
struct QuantGroupShape
|
||||
{
|
||||
static constexpr index_t kM = GroupSizes::at(number<0>{});
|
||||
static constexpr index_t kN = GroupSizes::at(number<1>{});
|
||||
static constexpr index_t kK = GroupSizes::at(number<2>{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -48,6 +48,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = QuantGroupSize_;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
@@ -67,12 +68,13 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
|
||||
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -81,8 +83,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"QuantGroupSize",
|
||||
kQuantGroupSize);
|
||||
QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -111,7 +112,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -137,7 +138,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
@@ -175,7 +176,7 @@ using GemmRowColTensorQuantPipelineProblem =
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
|
||||
@@ -15,10 +15,11 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -25,6 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -68,10 +69,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(kKPerBlock + QuantGroupSize - 1) / QuantGroupSize;
|
||||
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
@@ -89,7 +90,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
@@ -26,17 +26,17 @@ template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
using QuantGroupSize = std::tuple_element_t<9, Tuple>;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
using QuantTraits = QuantTypeTraits<QuantType>;
|
||||
|
||||
@@ -31,7 +31,7 @@ struct GemmConfigBase
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
@@ -119,9 +119,9 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -135,7 +135,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
|
||||
|
||||
@@ -181,7 +181,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -359,11 +359,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::GemmConfig;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -375,8 +375,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// BQuant uses grouped quantization for B matrix
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
// BQuant uses block/grouped quantization for the B matrix (groups of kN along N and kK along K)
|
||||
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
|
||||
const ck_tile::index_t stride_BQ = BQK;
|
||||
|
||||
// Generate test data
|
||||
@@ -384,18 +385,18 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_n(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
|
||||
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem bq_bqk_bqn_dev_buf(bq_bqk_bqn.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
@@ -425,25 +426,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_n);
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
{
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
|
||||
}
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B
|
||||
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
@@ -467,7 +470,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
@@ -614,9 +617,9 @@ class TestCkTileGemmRowColQuant
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -831,9 +834,9 @@ class TestCkTileGemmTensorQuant
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
using typename Base::QuantGroupSize;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
|
||||
@@ -20,7 +20,15 @@ using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantT
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = std::integral_constant<unsigned int, 128>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
// 2D block sizes for BQuant: each group spans several elements along N as well as 128 along K
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
@@ -53,10 +61,38 @@ using AQuantTypes = ::testing::Types<
|
||||
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
// 1D cases: grouping only along the K axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
|
||||
// 2D cases: grouping along the N axis in addition to the K axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -77,6 +113,7 @@ using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
|
||||
@@ -93,6 +93,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
@@ -135,7 +137,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
128>, // QuantGroupSize
|
||||
QuantGroupSize>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -258,6 +260,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
@@ -495,7 +499,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
128,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user