[CK-Tile] Fix quant example code (#2813)

* initial commit

* remove extra files

* fixing errors

* updated ReadMe file for mapping of diff quants with diff configs

* addressing review comments

* addressing review comments

* Resolved merge conflicts

* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled

The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.

---------

Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-09-10 17:15:39 -07:00
committed by GitHub
parent b4207c01c7
commit 80a61afb9b
10 changed files with 343 additions and 1288 deletions

View File

@@ -8,12 +8,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -3,6 +3,7 @@
This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
- Row and Column-wise scaled: scaling implemented in Epilogue
## build
@@ -38,3 +39,13 @@ args:
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, rowcol)
```
User need to select correct mapping of config for each quant mode:
| | quant_mode as runtime argument | Config in cpp file |
|:--------|:-----:|-------|
| For selecting AQuant | aquant | GemmConfigQuant |
| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
| For selecting BQuant | bquant | GemmConfigQuant |
| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |

View File

@@ -1,234 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
#include "run_gemm_aquant_example.inc"
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
}

View File

@@ -1,231 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_bquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::BQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::BQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
;
}
#include "run_gemm_bquant_example.inc"
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for B.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }

View File

@@ -18,58 +18,45 @@ template <typename GemmConfig,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode>
ck_tile::QuantType QuantMode,
typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = GemmConfig::kPadM;
constexpr bool kPadN = GemmConfig::kPadN;
constexpr bool kPadK = GemmConfig::kPadK;
constexpr int kBlockPerCu = GemmConfig::kBlockPerCu;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
// B datatype is safe to use as compute type as it should be at least fp8
using ComputeDataType = typename TypeConfig::BDataType;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
typename TypeConfig::ADataType>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<GemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
QuantMode>;
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
QuantMode>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
GemmShape,
GemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -80,36 +67,50 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using CodegenPipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>,
using PipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>>;
using CodegenGemmPipeline =
tail_number_v>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<CodegenPipelineProblem>>;
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
QuantGroupSize,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmBQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::QDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
QuantGroupSize,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
@@ -119,18 +120,18 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue, QuantMode>;
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -150,16 +151,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << PipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
@@ -199,33 +200,9 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
void print_help(const char* program_name)
{
std::cout
<< "Usage: " << program_name << " [OPTIONS]\n"
<< "\n Parameters:\n"
<< " -quant_mode=MODE aquant (A quantization) or rowcol (row/column quantization)\n"
<< " -v=LEVEL 0=No validation, 1=CPU validation, 2=GPU validation\n"
<< " -prec=TYPE Data types: fp8, bf8, i4fp8, i4bf8, i4f32fp8, i4f32bf8\n"
<< " -m, -n, -k=SIZE Matrix dimensions (M×K) @ (K×N) = (M×N)\n"
<< "\nExample:\n"
<< " " << program_name << " -quant_mode=rowcol -v=1 -prec=fp8\n"
<< std::endl;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
for(int i = 1; i < argc; ++i)
{
if(strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0)
{
print_help(argv[0]);
return 0;
}
}
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
@@ -239,7 +216,7 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
if(quant_mode == "aquant")
{
@@ -249,6 +226,14 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
@@ -259,12 +244,14 @@ int run_gemm_example(int argc, char* argv[])
}
else
{
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
}
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
if(quant_mode == "aquant")
{
@@ -274,6 +261,14 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
@@ -284,19 +279,20 @@ int run_gemm_example(int argc, char* argv[])
}
else
{
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
}
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
@@ -312,12 +308,12 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
@@ -329,42 +325,46 @@ int run_gemm_example(int argc, char* argv[])
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4f32fp8")
else if(data_type == "fp8i4")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
if(quant_mode == "aquant")
if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
}
else if(data_type == "i4f32bf8")
else if(data_type == "bf8i4")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
if(quant_mode == "aquant")
if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
}
else

View File

@@ -11,10 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -48,6 +44,13 @@ constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
#endif
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -81,6 +84,9 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -102,8 +108,22 @@ struct GemmConfigQuant : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
static constexpr int kBlockPerCu = 1;
template <typename PrecType>
struct GemmConfigRowColQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
@@ -122,8 +142,6 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr bool PreshuffleQuant = true;
};
@@ -206,25 +224,25 @@ auto create_args(int argc, char* argv[])
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("prec",
"fp8",
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr")
.insert("quant_mode", "aquant", "Choose aquant (default) or rowcol");
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -1,302 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <random>
#include <stdexcept>
#include "../00_shared/host_tensor_utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
{
if(t.get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t.get_lengths()[0];
int aqk_ = t.get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t AQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<AQDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
}
else
{
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
}
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -1,296 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <random>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>& t, int block_bq_k)
{
if(t.get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int n_ = t.get_lengths()[0];
int bqk_ = t.get_lengths()[1];
if(bqk_ % block_bq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a bqk of multiple times of block_bq_k.");
}
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename BQLayout,
typename DsLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& bq_bqk_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t BQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_BQ,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK_B = BQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_BQ = stride_BQ;
float ave_time = gemm_calc_bquant<GemmConfig,
ADataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ADataType, // computeDatatype
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K +
sizeof(BQDataType) * BQK * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideBQ =" << stride_BQ
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " BQ_Type = " << DataTypeTraits<BQDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using BQDataType = typename TypeConfig::QDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t BQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<BQDataType> bq_bqk_n(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(bq_bqk_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(bq_bqk_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
bq_bqk_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
BDataType,
BQDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
BQLayout,
ck_tile::tuple<>,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
b_k_n_dev_buf,
bq_bqk_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
BQK,
stride_A,
stride_B,
stride_BQ,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -6,28 +6,21 @@
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t.get_lengths().size() != 2)
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t.get_lengths()[0];
int aqk_ = t.get_lengths()[1];
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t.begin(), t.end(), t_view.begin());
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
@@ -39,12 +32,13 @@ template <typename GemmConfig,
typename BQLayout,
typename CLayout,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode>
ck_tile::QuantType QuantMode,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_dev_buf,
ck_tile::DeviceMem* aq_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::DeviceMem* bq_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
@@ -57,11 +51,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
int n_repeat,
bool flush_cache,
int rotating_count)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_dev_buf.GetDeviceBuffer();
args.aq_ptr = (aq_dev_buf != nullptr) ? aq_dev_buf->GetDeviceBuffer() : nullptr;
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = (bq_dev_buf != nullptr) ? bq_dev_buf->GetDeviceBuffer() : nullptr;
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
@@ -83,17 +79,23 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
BLayout,
CLayout,
QuantGroupSize,
QuantMode>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
QuantMode,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(typename TypeConfig::ADataType) * M * K +
sizeof(typename TypeConfig::QDataType) * M * AQK +
sizeof(typename TypeConfig::BDataType) * N * K +
sizeof(typename TypeConfig::CDataType) * M * N;
if(aq_dev_buf != nullptr)
{
num_byte += sizeof(typename TypeConfig::QDataType) * M * AQK;
}
if(bq_dev_buf != nullptr)
{
num_byte += sizeof(typename TypeConfig::QDataType) * BQK;
num_byte += sizeof(typename TypeConfig::QDataType) * N * BQK;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -103,7 +105,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
std::cout << " StrideBQ =" << stride_BQ;
}
@@ -117,8 +120,12 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = "
<< (QuantMode == ck_tile::QuantType::AQuantGrouped ? "AQuantGrouped" : "RowColQuant")
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
? "AQuantGrouped"
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
: "RowColQuant"))
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
@@ -156,7 +163,8 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if(K % QuantGroupSize != 0)
{
@@ -164,13 +172,17 @@ int run_gemm_example_with_layouts(int argc,
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
}
}
ck_tile::index_t AQK, BQK;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
AQK = 1; // Row quantization: tensor shape [M, 1]
@@ -186,12 +198,14 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t stride_BQ = 0;
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool flush_cache = arg_parser.get_bool("flush_cache");
int rotating_count = arg_parser.get_int("rotating_count");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -201,6 +215,12 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_BQ = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQ = 0; // No A quantization
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
@@ -216,21 +236,26 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// Create AQ tensor with appropriate shape
ck_tile::HostTensor<AQDataType> aq_tensor = [&]() {
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
return ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
else
{
return ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
}();
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
else if(QuantMode == ck_tile::QuantType::RowColQuant)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
// Create BQ tensor only for RowColQuant mode
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
@@ -242,22 +267,41 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_tensor);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
}
else if(init_method == 1)
@@ -267,46 +311,65 @@ int run_gemm_example_with_layouts(int argc,
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_tensor);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
}
}
else
{
a_m_k.SetZero();
aq_tensor.SetZero();
aq_tensor_ptr->SetZero();
b_k_n.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr->SetZero();
}
bq_tensor_ptr->SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_dev_buf(aq_tensor.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
aq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
}
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
}
if constexpr(GemmConfig::PreshuffleQuant && QuantMode == ck_tile::QuantType::AQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_tensor, GemmConfig::K_Tile / QuantGroupSize);
aq_dev_buf.ToDevice(aq_shuffle_host.data());
}
else
{
aq_dev_buf.ToDevice(aq_tensor.data());
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
}
else
{
aq_dev_buf_ptr->ToDevice(aq_tensor_ptr->data());
}
}
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
@@ -320,11 +383,22 @@ int run_gemm_example_with_layouts(int argc,
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
@@ -338,11 +412,16 @@ int run_gemm_example_with_layouts(int argc,
CLayout,
QuantGroupSize,
QuantMode>(a_m_k_dev_buf,
aq_dev_buf,
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? aq_dev_buf_ptr.get()
: nullptr,
b_k_n_dev_buf,
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? bq_dev_buf_ptr.get()
: nullptr,
c_m_n_dev_buf,
(QuantMode == ck_tile::QuantType::RowColQuant) ? bq_dev_buf_ptr.get()
: nullptr,
M,
N,
K,
@@ -355,7 +434,9 @@ int run_gemm_example_with_layouts(int argc,
stride_C,
kbatch,
n_warmup,
n_repeat);
n_repeat,
flush_cache,
rotating_count);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
@@ -374,7 +455,17 @@ int run_gemm_example_with_layouts(int argc,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_tensor, b_k_n, c_m_n_host_ref);
true>(a_m_k, *aq_tensor_ptr, b_k_n, c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
}
else
{
@@ -384,7 +475,7 @@ int run_gemm_example_with_layouts(int argc,
BQDataType,
AccDataType,
CDataType>(
a_m_k, aq_tensor, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
const float max_accumulated_value =

View File

@@ -70,17 +70,21 @@ struct get_bq_data_type_or<T, Default>
using type = typename T::BQDataType;
};
template <typename T, typename Default>
struct get_preshuffle_or
{
using type = Default;
template <typename T>
concept HasStaticPreshuffleQuant = requires {
{ T::PreshuffleQuant } -> std::convertible_to<decltype(T::PreshuffleQuant)>;
};
template <typename T, typename Default>
requires requires { typename T::PreshuffleQuant; }
struct get_preshuffle_or<T, Default>
template <typename T>
struct is_quantpreshuffle_enabled
{
using type = typename T::PreshuffleQuant;
static constexpr bool value = false;
};
template <HasStaticPreshuffleQuant T>
struct is_quantpreshuffle_enabled<T>
{
static constexpr auto value = T::PreshuffleQuant;
};
} // namespace detail
@@ -198,9 +202,9 @@ struct QuantGemmKernel
using BQLayout = remove_cvref_t<
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant = remove_cvref_t<
typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant =
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;