mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] ABQuant New Preshuffle (#3638)
* Refactor * Gemm quant improvement * Change preshuffle * Fix * Fix grouped gemm ut * Fix --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -6,6 +6,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
// template <typename T>
|
||||
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
@@ -135,4 +134,5 @@ void abquant_quantgrouped_instance_factory(
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -10,9 +10,8 @@ using GemmConfig = GemmConfigQuantDecodeInterwave<T>;
|
||||
// template <typename T>
|
||||
// using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
@@ -56,4 +55,5 @@ void aquant_quantgrouped_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
|
||||
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
@@ -52,4 +51,5 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_raw_t,
|
||||
ck_tile::bf16_t,
|
||||
@@ -38,4 +37,5 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
@@ -55,4 +54,5 @@ void bquant_quantgrouped_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -57,4 +56,5 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
@@ -55,4 +54,5 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -57,4 +56,5 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
@@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
@@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -54,4 +53,5 @@ void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
@@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
@@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -49,4 +48,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
@@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
@@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -56,4 +55,5 @@ void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -95,51 +95,6 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
return hash_multiple_strings(params);
|
||||
}
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -153,29 +108,8 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Device ID: " << device_id << std::endl;
|
||||
ck_tile::hip_check_error(hipSetDevice(device_id));
|
||||
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
abquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
auto& lut = get_kernel_lut();
|
||||
std::cout << "Available kernels: " << lut.size() << std::endl;
|
||||
|
||||
auto key = gen_lut_key(arg_parser);
|
||||
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
// NOTE: QuantGroupSize is a place holder. rowcol pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
@@ -27,4 +26,5 @@ void quant_rowcol_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::RowColQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
// NOTE: QuantGroupSize is a place holder. tensor pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
@@ -27,4 +26,5 @@ void quant_tensor_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::TensorQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -11,6 +11,14 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
inline auto& get_kernel_lut()
|
||||
{
|
||||
// In an inline function, function-local static objects in all function definitions are shared
|
||||
// across all translation units.
|
||||
static std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
return lut;
|
||||
}
|
||||
|
||||
inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
{
|
||||
std::hash<std::string> hasher;
|
||||
|
||||
@@ -80,10 +80,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
@@ -553,8 +552,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
@@ -630,7 +628,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return 0;
|
||||
return -1;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
@@ -1078,10 +1076,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return pass;
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
// Usage of Two-Matrix Quantization (AB-Quant)
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -1137,7 +1137,7 @@ CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
|
||||
CK_TILE_DEVICE static constexpr auto get_n_dwords_per_128b() { return 4; }
|
||||
|
||||
namespace detail {
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
|
||||
|
||||
@@ -69,7 +69,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, GemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -79,36 +79,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
else if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
|
||||
constexpr int ItemsPerAccess =
|
||||
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / ItemsPerAccess,
|
||||
ItemsPerAccess});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
@@ -160,7 +160,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto K0PerThreadRead = AK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
|
||||
@@ -250,7 +250,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
@@ -357,7 +357,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b();
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
|
||||
@@ -450,7 +450,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto NLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
|
||||
@@ -151,6 +151,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
|
||||
constexpr index_t kNPerBlock = TileShape::kN;
|
||||
constexpr index_t kKPerBlock = TileShape::kK;
|
||||
@@ -162,16 +163,18 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KBPerLoad = min(
|
||||
GetKBPerLoad<Problem>(), KRepeatInWave * 16 / static_cast<index_t>(sizeof(BDataType)));
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = KIterPerWarp;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
constexpr index_t KAccess = GetKBPerLoad<Problem>() / KBPerLoad;
|
||||
static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
@@ -181,16 +184,16 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KAccess, KWavePerBlk, KThdPerWave, KBPerLoad>>,
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2, 3>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 2, 1, 2>,
|
||||
sequence<0, 0, 3, 3>>{});
|
||||
sequence<1, 2, 1, 2, 2>,
|
||||
sequence<0, 0, 3, 1, 4>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -256,13 +259,22 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
constexpr index_t KLaneBytes =
|
||||
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
NumAccess>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
|
||||
@@ -693,13 +693,13 @@ struct QuantGemmKernel
|
||||
{
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
constexpr auto warp_k = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
index_t kFlatKSplit = GemmPipeline::flatKPerWarp * (k_size / warp_k);
|
||||
index_t kFlatK = GemmPipeline::flatKPerWarp * (kargs.K / warp_k);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatN, kFlatKSplit),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -52,11 +52,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KBPerLoad =
|
||||
min(GetKBPerLoad<Problem>(), 16 / static_cast<index_t>(sizeof(BDataType)));
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
@@ -64,8 +66,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
constexpr index_t KRepeat = GetKBPerLoad<Problem>() / KBPerLoad;
|
||||
static_assert(TileShape::flatKPerWarp == KRepeat * KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
@@ -98,13 +100,23 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
constexpr index_t KLaneBytes =
|
||||
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
NumAccess>;
|
||||
|
||||
// TODO : Use a custom block policy for AsBrCr
|
||||
using BlockGemmPolicy =
|
||||
|
||||
Reference in New Issue
Block a user