mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * addressing review comments * fixing CI issue * addressing reveiw comments * formatting * formatting * fixing aquant operator overlaoding * formatting --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -10,11 +10,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
|
||||
gemm_quant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
gemm_quant_rowcol.cpp
|
||||
gemm_quant_tensor.cpp
|
||||
)
|
||||
|
||||
@@ -33,47 +33,50 @@ mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Compile the quant kernels
|
||||
make tile_example_gemm_quant_basic -j
|
||||
make tile_example_gemm_quant -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
This will result in an executable `build/bin/tile_example_gemm_quant`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-h Print help message (default:false)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-a_layout A tensor data layout - Row or Column (default:R)
|
||||
-b_layout B tensor data layout - Row or Column (default:C)
|
||||
-bq_layout Bq tensor data layout - Row or Column (default:C)
|
||||
-c_layout C tensor data layout - Row or Column (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_q Tensor AQ stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
|
||||
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
|
||||
-warmup Number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:1000)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k SplitK value (default:1)
|
||||
-device Device id that will be used to run the kernel (default:0)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache Flush cache before running the kernel (default:true)
|
||||
-rotating_count Rotating count (default:1000)
|
||||
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
|
||||
-preshuffleb Enable preshuffle of tensor B (default:false)
|
||||
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
|
||||
-h Print help message (default:false)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-a_layout A tensor data layout - R for Row or C for Column (default:R)
|
||||
-b_layout B tensor data layout - R for Row or C for Column (default:C)
|
||||
-bq_layout Bq tensor data layout - R for Row or C for Column (default:C)
|
||||
-c_layout C tensor data layout - R for Row or C for Column (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_q Tensor AQ stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
|
||||
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
|
||||
-warmup Number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:1000)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k SplitK value (default:1)
|
||||
-device Device id that will be used to run the kernel (default:0)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache Flush cache before running the kernel (default:true)
|
||||
-rotating_count Rotating count (default:1000)
|
||||
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
|
||||
-preshuffleb Enable preshuffle of tensor B (default:false)
|
||||
-preshufflequant Enable preshuffle of quant tensor (default:false)
|
||||
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
|
||||
```
|
||||
|
||||
User need to select correct mapping of config for each quant mode:
|
||||
|
||||
| | quant_mode as runtime argument | Config in cpp file |
|
||||
|:--------|:-----:|-------|
|
||||
| For selecting AQuant | aquant | GemmConfigQuant |
|
||||
| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
|
||||
| For selecting BQuant | bquant | GemmConfigQuant |
|
||||
| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill
|
||||
| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |
|
||||
| | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of cpp file |
|
||||
|:--------|:-----:|:-----:|-------|
|
||||
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
|
||||
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
|
||||
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
|
||||
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill
|
||||
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
| For selecting RowCol quant | rowcolquant | gemm_quant_rowcol| GemmConfigRowColQuant |
|
||||
|
||||
|
||||
@@ -4,14 +4,15 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
@@ -19,8 +20,9 @@ void aquant_quantgrouped_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
@@ -28,7 +30,7 @@ void aquant_quantgrouped_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] =
|
||||
lut[hash_multiple_strings({"fp8i4", "aquant", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
@@ -39,7 +41,7 @@ void aquant_quantgrouped_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] =
|
||||
lut[hash_multiple_strings({"bf8i4", "aquant", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
|
||||
@@ -4,50 +4,52 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<T>;
|
||||
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
{"bf8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings({"fp8i4", "aquant", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings({"bf8i4", "aquant", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -18,28 +18,33 @@ void bquant_quantgrouped_bf8_instance_factory(
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
@@ -20,28 +20,33 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
@@ -18,28 +18,33 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
@@ -20,28 +20,33 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -43,6 +43,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
@@ -58,11 +59,21 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::vector<std::string> params = {data_type, quant_mode};
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
std::string preshufflequant =
|
||||
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
|
||||
params.push_back(preshufflequant);
|
||||
}
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
std::string preshuffleb =
|
||||
arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb";
|
||||
params.push_back(preshuffleb);
|
||||
|
||||
std::string preshufflequant =
|
||||
arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant";
|
||||
params.push_back(preshufflequant);
|
||||
}
|
||||
if(quant_mode != "rowcol" && quant_mode != "tensor")
|
||||
{
|
||||
@@ -76,6 +87,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
@@ -86,6 +99,10 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_tensor_instance_factory(
|
||||
@@ -106,11 +123,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
|
||||
@@ -122,9 +142,9 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported."
|
||||
<< std::endl;
|
||||
std::cerr << "Error: Combination of prec, quant_mode, preshuffleb, preshufflequant, and "
|
||||
"group_size not supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
|
||||
@@ -110,7 +110,7 @@ struct GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigQuant : public GemmConfigBase
|
||||
struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -142,7 +142,7 @@ struct GemmConfigRowColQuant : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -161,7 +161,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -184,7 +184,14 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
|
||||
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -206,6 +213,13 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
@@ -222,6 +236,12 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
|
||||
{
|
||||
|
||||
@@ -557,7 +557,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
aq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
@@ -626,8 +625,24 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
|
||||
}
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user