[CK_TILE] ABQuant New Preshuffle (#3638)

* Refactor

* Gemm quant improvement

* Change preshuffle

* Fix

* Fix grouped gemm ut

* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Yi DING
2026-01-28 15:46:49 +08:00
committed by GitHub
parent 91e32f305f
commit 8e3d84aba3
32 changed files with 182 additions and 213 deletions

View File

@@ -6,6 +6,7 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")

View File

@@ -12,9 +12,8 @@ using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
@@ -135,4 +134,5 @@ void abquant_quantgrouped_instance_factory(
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}
return 0;
}();

View File

@@ -10,9 +10,8 @@ using GemmConfig = GemmConfigQuantDecodeInterwave<T>;
// template <typename T>
// using GemmConfig = GemmConfigQuantPrefill<T>;
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
@@ -56,4 +55,5 @@ void aquant_quantgrouped_instance_factory(
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}
return 0;
}();

View File

@@ -6,9 +6,8 @@
template <typename T>
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
@@ -52,4 +51,5 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_bf16fp4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::pk_fp4_raw_t,
ck_tile::bf16_t,
@@ -38,4 +37,5 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
@@ -55,4 +54,5 @@ void bquant_quantgrouped_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -57,4 +56,5 @@ void bquant_quantgrouped_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
#ifndef CK_GFX950_SUPPORT
@@ -55,4 +54,5 @@ void bquant_quantgrouped_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -57,4 +56,5 @@ void bquant_quantgrouped_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
@@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
@@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
@@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
@@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
@@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
@@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
@@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
return 0;
}();

View File

@@ -95,51 +95,6 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
return hash_multiple_strings(params);
}
void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void aquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_bf16fp4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void quant_rowcol_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void quant_tensor_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -153,29 +108,8 @@ int main(int argc, char* argv[])
std::cout << "Device ID: " << device_id << std::endl;
ck_tile::hip_check_error(hipSetDevice(device_id));
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
abquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_instance_factory(lut);
aquant_quantgrouped_preshufflequant_instance_factory(lut);
bquant_quantgrouped_fp8_instance_factory(lut);
bquant_quantgrouped_bf8_instance_factory(lut);
bquant_quantgrouped_fp8i4_instance_factory(lut);
bquant_quantgrouped_bf8i4_instance_factory(lut);
bquant_quantgrouped_bf16fp4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut);
bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut);
bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut);
bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut);
bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut);
bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut);
bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut);
bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut);
quant_rowcol_instance_factory(lut);
quant_tensor_instance_factory(lut);
auto& lut = get_kernel_lut();
std::cout << "Available kernels: " << lut.size() << std::endl;
auto key = gen_lut_key(arg_parser);

View File

@@ -6,9 +6,8 @@
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
void quant_rowcol_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
// NOTE: QuantGroupSize is a place holder. rowcol pipeline does not use QuantGroupSize
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) {
@@ -27,4 +26,5 @@ void quant_rowcol_instance_factory(
QuantGroupSize,
ck_tile::QuantType::RowColQuant>(arg_parser);
};
}
return 0;
}();

View File

@@ -6,9 +6,8 @@
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
void quant_tensor_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
static auto _ = []() {
auto& lut = get_kernel_lut();
// NOTE: QuantGroupSize is a place holder. tensor pipeline does not use QuantGroupSize
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) {
@@ -27,4 +26,5 @@ void quant_tensor_instance_factory(
QuantGroupSize,
ck_tile::QuantType::TensorQuant>(arg_parser);
};
}
return 0;
}();

View File

@@ -11,6 +11,14 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
inline auto& get_kernel_lut()
{
// In an inline function, function-local static objects in all function definitions are shared
// across all translation units.
static std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
return lut;
}
inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
{
std::hash<std::string> hasher;

View File

@@ -80,10 +80,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
@@ -553,8 +552,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}
std::random_device rd;
std::mt19937 gen(rd());
std::mt19937 gen(42);
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
@@ -630,7 +628,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
return -1;
}
else if(init_method == 2)
{
@@ -1078,10 +1076,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
return -1;
}
return pass;
return pass ? 0 : -1;
}
// Usage of Two-Matrix Quantization (AB-Quant)
template <typename GemmConfig,