mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Tile] Support for group size 128 for Preshuffle quant for 2d block scale gemm (#3462)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * G128 working for both prefill and decode shapes for preshufflequant * clean up after merging with develop * fixing group 64 for decode shapes * non preshufflequant working for group size 128 * enable preshuffleb and preshufflequant with variour group sizes * reduce build time by splitting example into diff datatype files * Adding tests for preshuffleQuant * address review comment * fix for gfx1201 * compile time fix for gfx1201 * clang formatted --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
@@ -20,9 +20,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp
|
||||
gemm_quant_rowcol.cpp
|
||||
gemm_quant_tensor.cpp
|
||||
)
|
||||
|
||||
@@ -49,4 +49,10 @@ void bquant_quantgrouped_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -51,4 +51,10 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -51,4 +51,10 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -1,270 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -111,11 +111,29 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
@@ -144,9 +162,18 @@ int main(int argc, char* argv[])
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
|
||||
|
||||
@@ -319,7 +319,23 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
// constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN))
|
||||
{
|
||||
if constexpr(Traits::NPerBlock ==
|
||||
GemmTraits::QuantGroupSize::kN)
|
||||
return kQScale;
|
||||
else
|
||||
return nIter; // for prefill needs kQscale, for decode needs
|
||||
// nIter
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
}
|
||||
}();
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
|
||||
@@ -887,23 +887,27 @@ struct QuantGemmKernel
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock /
|
||||
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
|
||||
I1); // Number of N-dimension elements per warp
|
||||
constexpr auto warp_per_group =
|
||||
(QuantGroupSize::kN <
|
||||
warp_n) // Determine how many warps share the same scale in N-dimension
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto bqk_per_block =
|
||||
TilePartitioner::KPerBlock /
|
||||
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
|
||||
constexpr auto
|
||||
tile_window_width = // The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
|
||||
// Number of N-dimension quantization groups per block
|
||||
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
|
||||
|
||||
// Number of N-dimension elements per warp
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
// Determine how many warps share the same scale in N-dimension
|
||||
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
|
||||
// Number of K-dimension quantization groups per block
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
// The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
|
||||
// Adapts based on fine vs coarse quantization granularity:
|
||||
@@ -916,23 +920,42 @@ struct QuantGemmKernel
|
||||
// height = block_n
|
||||
constexpr auto tile_window_height =
|
||||
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
auto block_n_idx =
|
||||
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
|
||||
// block index.
|
||||
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
|
||||
|
||||
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
block_n_idx = block_n_idx >> 1;
|
||||
}
|
||||
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto tensor_dim =
|
||||
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
number<tensor_dim>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
@@ -940,7 +963,7 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
make_tuple(number<tensor_dim>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
|
||||
@@ -26,14 +26,15 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
|
||||
@@ -45,7 +45,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::QuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
@@ -66,7 +66,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
(QuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
|
||||
|
||||
@@ -240,20 +240,26 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
//
|
||||
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
|
||||
// → 2 scales per warp in N, 2 K-groups per block
|
||||
constexpr auto N1 = BlockGemmShape::kK /
|
||||
KPerQ; // Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
constexpr auto N0 =
|
||||
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
|
||||
|
||||
// N1: Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
// N0: Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
// N2: Elements per thread
|
||||
// NR1: Elements sharing the same scale in N-dimension
|
||||
// NR0: Interleave factor to ensure full warp utilization
|
||||
// K1: Number of warps distributed along this dimension
|
||||
// K0: Iterations per warp to cover the K-tile
|
||||
// KR: No replication in K-dimension
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = WarpGemm::kN / NPerQ;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ;
|
||||
constexpr auto NR0 =
|
||||
warp_size /
|
||||
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
|
||||
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
|
||||
constexpr auto KR = 1; // No replication in K-dimension
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
constexpr auto K1 = NWarps;
|
||||
constexpr auto K0 = KPerTile / K1;
|
||||
constexpr auto KR = 1;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
@@ -275,15 +281,24 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
|
||||
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
|
||||
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
|
||||
// KR: Number of warps sharing the same scale
|
||||
// K1: Number of distinct warp groups (unique scales)
|
||||
// K0: Iterations to cover K-tile per warp group
|
||||
// N1: K-dimension quantization groups
|
||||
// N0: Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
// N2: Elements per thread
|
||||
// NR1: Scale broadcast factor (full NPerQ)
|
||||
// NR0: Remaining interleave factor
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN;
|
||||
constexpr auto K1 = NWarps / KR;
|
||||
constexpr auto K0 = KPerTile / K1;
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = 1;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ;
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
@@ -303,12 +318,19 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
//
|
||||
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
|
||||
// → 128 >= 16*4=64, so all 4 warps use the same scale
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = 32; // Fixed broadcast size
|
||||
|
||||
// N1: K-dimension quantization groups
|
||||
// N0: Minimal (1) since scale is shared across N
|
||||
// N2: Elements per thread
|
||||
// NR1: Fixed broadcast size
|
||||
// NR0: Remaining interleave factor
|
||||
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = 1;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = 32;
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
|
||||
@@ -79,10 +79,8 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % AQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
|
||||
static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % BQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
|
||||
@@ -144,23 +144,32 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// Insert LDS read/write groups periodically based on ds_rep.
|
||||
// The % pattern staggers READ and WRITE so they don't collapse
|
||||
// into the same cycle in the model.
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
|
||||
if constexpr(ds_rep > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
if(i_inst % ds_rep == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
if constexpr(ds_rep > 0)
|
||||
{
|
||||
if(i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0)
|
||||
{
|
||||
if(i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
}
|
||||
}
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
@@ -354,7 +363,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
@@ -431,7 +440,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
@@ -468,7 +477,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
|
||||
@@ -117,6 +117,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleQuant) - split into 4 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d
|
||||
test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d
|
||||
test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d
|
||||
test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d
|
||||
test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# RowColQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_rowcol
|
||||
test_gemm_quant_rowcol.cpp
|
||||
@@ -152,6 +173,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_bquant_preshuffle_tiled_permute
|
||||
test_tile_gemm_quant_bquant_preshuffle_decode_2d
|
||||
test_tile_gemm_quant_bquant_preshuffle_prefill_2d
|
||||
# BQuant preshuffleQuant tests
|
||||
test_tile_gemm_quant_bquant_preshuffleQuant_decode_1d
|
||||
test_tile_gemm_quant_bquant_preshuffleQuant_prefill_1d
|
||||
test_tile_gemm_quant_bquant_preshuffleQuant_decode_2d
|
||||
test_tile_gemm_quant_bquant_preshuffleQuant_prefill_2d
|
||||
# Other quant tests
|
||||
test_tile_gemm_quant_rowcol
|
||||
test_tile_gemm_quant_tensor
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Decode Config 1D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleDecode1DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Decode 1D
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Decode 2D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshuffleDecode2DTypes = ::testing::Types<
|
||||
// 2d cases with preshuffle B
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantDecode, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Decode 2D
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Prefill Config 1D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshufflePrefill1DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Prefill 1D
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Prefill 2D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BPreshufflePrefill2DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D8N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D16N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant Preshuffle Prefill 2D
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes);
|
||||
|
||||
// BQuant PreshuffleB tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -19,10 +19,11 @@ using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Decode 2D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
@@ -37,7 +38,9 @@ using BPreshuffleDecode2DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -19,10 +19,11 @@ using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
using GroupSize2D16N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for BQuant Preshuffle tests - Prefill 2D
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
@@ -44,7 +45,11 @@ using BPreshufflePrefill2DTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D128N>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D128N>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -53,11 +53,20 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
};
|
||||
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
@@ -89,42 +98,26 @@ struct GemmConfigPadding : public GemmConfigBase
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBDecode : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleBDecode : public GemmConfigDecode
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefill : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
|
||||
|
||||
Reference in New Issue
Block a user