mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK TILE GEMM] Refactor block_scale_gemm examples (#3181)
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Split cpp file to reduce building time
- Support multiple GemmConfig
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Update Readme
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Add support for rowcol and tensor GEMM operations
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Update README
* [CK TILE GEMM] Refactor block_scale_gemm examples
- Set quant group size to (1, 1, 64) for targets excluding gfx950, where warp tile size (16, 16, 128) is incompatible.
[ROCm/composable_kernel commit: 6fd8ddabe7]
This commit is contained in:
@@ -6,8 +6,19 @@ endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
set(EXE_NAME tile_example_gemm_quant)
|
||||
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
|
||||
gemm_quant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
|
||||
gemm_quant_rowcol.cpp
|
||||
gemm_quant_tensor.cpp
|
||||
)
|
||||
target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-b batch size (default:1)
|
||||
-m m dimension (default:1024)
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: C)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
|
||||
-h Print help message (default:false)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-a_layout A tensor data layout - Row or Column (default:R)
|
||||
-b_layout B tensor data layout - Row or Column (default:C)
|
||||
-bq_layout Bq tensor data layout - Row or Column (default:C)
|
||||
-c_layout C tensor data layout - Row or Column (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_q Tensor AQ stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
|
||||
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
|
||||
-warmup Number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:1000)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k SplitK value (default:1)
|
||||
-device Device id that will be used to run the kernel (default:0)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache Flush cache before running the kernel (default:true)
|
||||
-rotating_count Rotating count (default:1000)
|
||||
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
|
||||
-preshuffleb Enable preshuffle of tensor B (default:false)
|
||||
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
|
||||
```
|
||||
|
||||
User need to select correct mapping of config for each quant mode:
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
#endif
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
130
example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
Normal file
130
example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
Normal file
@@ -0,0 +1,130 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("h", "false", "Print help message")
|
||||
.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row or Column")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row or Column")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row or Column")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU")
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"or bf8i4")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
|
||||
std::vector<std::string> params = {data_type, quant_mode};
|
||||
|
||||
if(quant_mode == "bquant")
|
||||
{
|
||||
std::string preshuffleb =
|
||||
arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb";
|
||||
params.push_back(preshuffleb);
|
||||
}
|
||||
if(quant_mode != "rowcol" && quant_mode != "tensor")
|
||||
{
|
||||
// NOTE: rowcol and tensor pipeline do not use group size
|
||||
std::string group_size_str = arg_parser.get_str("group_size");
|
||||
params.push_back(group_size_str);
|
||||
}
|
||||
|
||||
return hash_multiple_strings(params);
|
||||
}
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result || arg_parser.get_bool("h"))
|
||||
{
|
||||
arg_parser.print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
auto device_id = arg_parser.get_int("device");
|
||||
std::cout << "Device ID: " << device_id << std::endl;
|
||||
ck_tile::hip_check_error(hipSetDevice(device_id));
|
||||
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
|
||||
auto key = gen_lut_key(arg_parser);
|
||||
|
||||
if(lut.find(key) != lut.end())
|
||||
{
|
||||
return lut[key](arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::ADataType>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<GemmShape>;
|
||||
|
||||
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
ALayout, // for AQLayout
|
||||
BLayout, // for BQLayout
|
||||
false,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
// This example only supports BQuant (no AQuant)
|
||||
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>, // memory pipeline hardcoded
|
||||
// for aquant
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << PipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>
|
||||
rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.c_ptr,
|
||||
0,
|
||||
args.M * args.N * sizeof(typename TypeConfig::CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Forward declaration for dispatch function
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[]);
|
||||
|
||||
// Helper function to parse group size string "MxNxK"
|
||||
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
|
||||
{
|
||||
int m = 1, n = 1, k = 128;
|
||||
|
||||
size_t first_x = group_size_str.find('x');
|
||||
if(first_x == std::string::npos)
|
||||
{
|
||||
// Single number provided, assume it's the K dimension
|
||||
k = std::stoi(group_size_str);
|
||||
return {1, 1, k};
|
||||
}
|
||||
|
||||
size_t second_x = group_size_str.find('x', first_x + 1);
|
||||
if(second_x == std::string::npos)
|
||||
{
|
||||
throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
|
||||
}
|
||||
|
||||
m = std::stoi(group_size_str.substr(0, first_x));
|
||||
n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
|
||||
k = std::stoi(group_size_str.substr(second_x + 1));
|
||||
|
||||
return {m, n, k};
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
std::string group_size_str = arg_parser.get_str("group_size");
|
||||
|
||||
auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
|
||||
|
||||
// Dispatch based on group size (M, N, K)
|
||||
return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
|
||||
using QuantGroupSize = decltype(QGS_);
|
||||
return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
|
||||
data_type, quant_mode, a_layout, b_layout, argc, argv);
|
||||
});
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
|
||||
int dispatch_by_data_type(const std::string& data_type,
|
||||
const std::string& quant_mode,
|
||||
const std::string& a_layout,
|
||||
const std::string& b_layout,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
// This example ONLY supports BQuant for 2D block scale quantization
|
||||
if(quant_mode != "bquant")
|
||||
{
|
||||
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
|
||||
}
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename> typename GemmConfig, typename F>
|
||||
int dispatch_group_size_ct(int m, int n, int k, F&& f)
|
||||
{
|
||||
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
|
||||
#define DISPATCH_ONE(M, N, K) \
|
||||
if(m == M && n == N && k == K) \
|
||||
{ \
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
|
||||
return f(QuantGroupSize{}); \
|
||||
}
|
||||
|
||||
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
|
||||
|
||||
#undef DISPATCH_ONE
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
|
||||
#else
|
||||
// Use non-preshuffled GemmConfig for 2D block scale support
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
30
example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp
Normal file
30
example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
// NOTE: QuantGroupSize is a place holder. rowcol pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::RowColQuant>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::RowColQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
30
example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp
Normal file
30
example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuant<T>;
|
||||
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
// NOTE: QuantGroupSize is a place holder. tensor pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::TensorQuant>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::TensorQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,12 +11,18 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
|
||||
X(1, 1, 64) /* 1D */ \
|
||||
X(1, 1, 128) /* 1D */ \
|
||||
X(1, 8, 128) /* 2D N=8 */ \
|
||||
X(1, 32, 128) /* 2D N=32 */ \
|
||||
X(1, 64, 128) /* 2D N=64 */
|
||||
inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
{
|
||||
std::hash<std::string> hasher;
|
||||
size_t combined_hash = 0;
|
||||
for(const auto& str : inputs)
|
||||
{
|
||||
// Hash combine using golden ratio constant and bit shifts for good distribution and
|
||||
// order-dependent mixing
|
||||
combined_hash ^= hasher(str) + 0x9e3779b9 + (combined_hash << 6) + (combined_hash >> 2);
|
||||
}
|
||||
return combined_hash;
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -293,37 +299,3 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,233 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename CDEElementWise>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::ADataType>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<GemmShape>;
|
||||
|
||||
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
ALayout, // for AQLayout
|
||||
BLayout, // for BQLayout
|
||||
false,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
// This example only supports BQuant (no AQuant)
|
||||
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << PipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>
|
||||
rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.c_ptr,
|
||||
0,
|
||||
args.M * args.N * sizeof(typename TypeConfig::CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
@@ -120,18 +342,13 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using AQDataType = typename TypeConfig::QDataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
@@ -522,3 +739,45 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user