[CK_TILE] Row/Col quant gemm (#2729)

* Add cshuffle epilogue test

* add the poc implementation to the epilogue and tests

* refactor cshuffle epilogue

* WIP: adding tensor/tile usage to scale_tile

* fix usage of tile_elementwise_inout

* add gemm_quant_kernel to generalize the gemm quant kernels

* Add quant-specific problem types, add QuantType to Traits

* Add quant_type to quant_kernel template parameters

* Create aq/bq_block_windows and views depending on QuantType

* Use tile windows as inputs in cshuffle epilogue

* Fix some issues in epilogue

* initial example code for the new general gemm quant kernel test

* Fix issues in kernel

* Add verification check for rowcol QuantMode

* use AccDataType instead of AQ in pipeline

* fix aquant preshuffle

* fix formatting

* some cleanup

* remove gemm_aquant_basic.cpp

* remove gemm_aquant_kernel.hpp

* fix tests for the renamed quant kernel

* fix formatting

* clean example files

* fix some merge conflicts

* fix preshufflequant rename issue

* fix some templates after merging with develop

* fix test preshuffle parameter

* fix formatting

* Unify bquant kernel into the common quant kernel

* remove bquant kernel also from common header

* fix formatting

* clean up commented code

* fix formatting in config hpp

* fix merge mistake

* Non-const for movable windows

* fix formatting

* Fix grammar in README

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Remove #include<bit> and clean up example

* fix strides

* Add some descriptions for move_windows

---------

Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

[ROCm/composable_kernel commit: c6010f2953]
Author: Sami Remes
Date: 2025-09-05 02:17:12 +03:00
Committed by: GitHub
Parent: ecc4a470ec
Commit: abf4f7a7b2
23 changed files with 1837 additions and 1331 deletions

View File

@@ -6,8 +6,12 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -1,18 +1,21 @@
# GEMM Matrix Multiplication
# Quant GEMM Matrix Multiplication
This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation.
This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.
- AQuant kernel, with blocks of the A matrix sharing scales: custom GEMM pipeline
- Row- and column-wise scaling: implemented in the epilogue
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# The aquant pipeline method on the gemm calculation
make tile_example_gemm_aquant_basic -j
# Compile the quant kernels
make tile_example_gemm_quant_basic -j
make tile_example_gemm_bquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
## example
```
@@ -22,15 +25,16 @@ args:
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-b_layout Tensor B data layout (default: C)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, rowcol)
```
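
To make the two modes concrete, here is a minimal scalar sketch of the dequantized dot product each mode computes (illustrative only, not part of this commit; `aquant_dot`, `rowcol_dot`, and `group_size` are hypothetical names):
```
#include <cstddef>
#include <vector>

// AQuant (grouped): each group of `group_size` consecutive K-elements in a row
// of A shares one scale, so A dequantizes as A[m][k] * aq[m][k / group_size].
float aquant_dot(const std::vector<float>& a_row,  // quantized A row, length K
                 const std::vector<float>& aq_row, // per-group scales, length K / group_size
                 const std::vector<float>& b_col,  // B column, length K
                 std::size_t group_size)
{
    float acc = 0.f;
    for(std::size_t k = 0; k < a_row.size(); ++k)
        acc += a_row[k] * aq_row[k / group_size] * b_col[k];
    return acc;
}

// Row/Col quant: one scale per row of A and one per column of B. The scale
// product factors out of the K-loop, which is why this variant can be applied
// once per output element in the epilogue instead of inside the GEMM pipeline.
float rowcol_dot(const std::vector<float>& a_row,
                 const std::vector<float>& b_col,
                 float a_scale,
                 float b_scale)
{
    float acc = 0.f;
    for(std::size_t k = 0; k < a_row.size(); ++k)
        acc += a_row[k] * b_col[k];
    return acc * a_scale * b_scale;
}
```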

View File

@@ -1,226 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = true;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
#include "run_gemm_aquant_example.inc"
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -21,7 +21,7 @@ template <typename GemmConfig,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -50,13 +50,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -109,8 +110,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -23,7 +23,7 @@ template <typename GemmConfig,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_calc_bquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -50,13 +50,14 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::BQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -108,8 +109,10 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::BQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -0,0 +1,376 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = GemmConfig::kPadM;
constexpr bool kPadN = GemmConfig::kPadN;
constexpr bool kPadK = GemmConfig::kPadK;
constexpr int kBlockPerCu = GemmConfig::kBlockPerCu;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
// B datatype is safe to use as compute type as it should be at least fp8
using ComputeDataType = typename TypeConfig::BDataType;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout,
QuantMode>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using CodegenPipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using CodegenGemmPipeline =
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<CodegenPipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue, QuantMode>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
#include "run_gemm_quant_example.inc"
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return 0;
}
void print_help(const char* program_name)
{
std::cout
<< "Usage: " << program_name << " [OPTIONS]\n"
<< "\n Parameters:\n"
<< " -quant_mode=MODE aquant (A quantization) or rowcol (row/column quantization)\n"
<< " -v=LEVEL 0=No validation, 1=CPU validation, 2=GPU validation\n"
<< " -prec=TYPE Data types: fp8, bf8, i4fp8, i4bf8, i4f32fp8, i4f32bf8\n"
<< " -m, -n, -k=SIZE Matrix dimensions (M×K) @ (K×N) = (M×N)\n"
<< "\nExample:\n"
<< " " << program_name << " -quant_mode=rowcol -v=1 -prec=fp8\n"
<< std::endl;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
for(int i = 1; i < argc; ++i)
{
if(strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0)
{
print_help(argv[0]);
return 0;
}
}
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string quant_mode = arg_parser.get_str("quant_mode");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
}
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
}
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::fp8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::bf8_t>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }

View File

@@ -83,14 +83,13 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigDecode : public GemmConfigBase
struct GemmConfigQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
@@ -105,29 +104,6 @@ struct GemmConfigDecode : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
@@ -148,9 +124,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
@@ -237,19 +211,20 @@ auto create_args(int argc, char* argv[])
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr")
.insert("quant_mode", "aquant", "Choose aquant (default) or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -3,7 +3,6 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <bit>
#include <random>
#include <stdexcept>
@@ -59,7 +58,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
@@ -68,7 +67,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;

View File

@@ -2,7 +2,6 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <bit>
#include <random>
template <typename Layout>
@@ -60,7 +59,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::BQuantGemmHostArgs args;
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
@@ -69,7 +68,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.M = M;
args.N = N;
args.K = K;
args.QK = BQK;
args.QK_B = BQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;

View File

@@ -0,0 +1,404 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <random>
#include <stdexcept>
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
{
if(t.get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t.get_lengths()[0];
int aqk_ = t.get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
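// For illustration: with m_ = 2 and aqk_ = 4, block_aq_k = 2 views the [2, 4]
// scale tensor as [2, 2, 2] = (m, k-block, in-block), and the {1, 0, 2} permute
// moves the k-block dimension outermost, which is presumably the order in which
// the PreshuffleQuant pipeline consumes the A scales.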
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::DeviceMem* bq_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t BQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_BQ,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = (bq_dev_buf != nullptr) ? bq_dev_buf->GetDeviceBuffer() : nullptr;
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK_A = AQK;
args.QK_B = BQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
args.stride_BQ = stride_BQ;
float ave_time = gemm_calc_quant<GemmConfig,
TypeConfig,
ALayout,
BLayout,
CLayout,
QuantGroupSize,
QuantMode>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(typename TypeConfig::ADataType) * M * K +
sizeof(typename TypeConfig::QDataType) * M * AQK +
sizeof(typename TypeConfig::BDataType) * N * K +
sizeof(typename TypeConfig::CDataType) * M * N;
if(bq_dev_buf != nullptr)
{
num_byte += sizeof(typename TypeConfig::QDataType) * BQK;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
std::cout << " StrideBQ =" << stride_BQ;
}
std::cout << " A_Type = " << DataTypeTraits<typename TypeConfig::ADataType>::name
<< " AQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name
<< " B_Type = " << DataTypeTraits<typename TypeConfig::BDataType>::name;
if constexpr(!std::is_same_v<typename TypeConfig::QDataType, void>)
{
std::cout << " BQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name;
}
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = "
<< (QuantMode == ck_tile::QuantType::AQuantGrouped ? "AQuantGrouped" : "RowColQuant")
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
ck_tile::QuantType QuantMode,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using BQDataType = typename TypeConfig::QDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if(K % QuantGroupSize != 0)
{
throw std::runtime_error(
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
}
}
ck_tile::index_t AQK, BQK;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
AQK = 1; // Row quantization: tensor shape [M, 1]
BQK = N; // Column quantization: tensor shape [1, N]
}
else
{
static_assert(false, "Unsupported QuantMode");
}
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t stride_BQ = 0;
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
// Conditional stride calculation based on QuantMode
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
}
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// Create AQ tensor with appropriate shape
ck_tile::HostTensor<AQDataType> aq_tensor = [&]() {
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
return ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
else
{
return ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
}();
// Create BQ tensor only for RowColQuant mode
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
}
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_tensor);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_tensor);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
}
else
{
a_m_k.SetZero();
aq_tensor.SetZero();
b_k_n.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr->SetZero();
}
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_dev_buf(aq_tensor.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
}
if constexpr(GemmConfig::PreshuffleQuant && QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_tensor, GemmConfig::K_Tile / QuantGroupSize);
aq_dev_buf.ToDevice(aq_shuffle_host.data());
}
else
{
aq_dev_buf.ToDevice(aq_tensor.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
invoke_gemm<GemmConfig,
TypeConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
QuantGroupSize,
QuantMode>(a_m_k_dev_buf,
aq_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
(QuantMode == ck_tile::QuantType::RowColQuant) ? bq_dev_buf_ptr.get()
: nullptr,
M,
N,
K,
AQK,
BQK,
stride_A,
stride_AQ,
stride_B,
stride_BQ,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_tensor, b_k_n, c_m_n_host_ref);
}
else
{
ck_tile::reference_gemm_rowcol_quant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(
a_m_k, aq_tensor, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -115,6 +115,92 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& aq_m_1,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& bq_1_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
// Init accumulator
AccDataType v_acc = 0;
// Get row scale for A and column scale for B
float a_scale = aq_m_1(m, 0);
float b_scale = bq_1_n(0, n);
// Compute the dot product
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
// Process A data
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
// Process B data
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
}
v_acc = v_acc * a_scale * b_scale;
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,

View File

@@ -6,6 +6,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <optional>
namespace ck_tile {
@@ -210,8 +213,12 @@ struct CShuffleEpilogue
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
@@ -257,11 +264,120 @@ struct CShuffleEpilogue
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
template <auto iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
template <auto iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
{
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
}
template <typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
{
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
}
template <typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
{
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
}
template <typename OutDramWindow, typename COutTensor>
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
const COutTensor& c_out_tensor)
{
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
}
/**
* @brief Move both the output and D tensors windows for the next access.
*/
template <auto iAccess, typename OutDramWindow, typename DDramWindows>
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
{
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// move the output dram window
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
// move windows for each of the D matrices (inputs for element-wise)
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
});
}
}
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
struct EmptyScale
{
};
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
void* p_smem,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
@@ -282,9 +398,6 @@ struct CShuffleEpilogue
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
@@ -306,60 +419,46 @@ struct CShuffleEpilogue
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
auto scale_m_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
if constexpr(has_scales)
{
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
}
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
cast_lds_tile(lds_tile, in_lds_window);
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
apply_d_tensors(d_dram_windows, c_out_tensor);
store_to_dram(out_dram_window, c_out_tensor);
move_windows<iAccess>(out_dram_window, d_dram_windows);
});
}
};
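
The net effect of the refactor: when scale windows are passed in, each space-filling-curve access multiplies the staged LDS tile by the broadcast row scales (M direction) and column scales (N direction) before the cast to `ODataType` and the store. A scalar sketch of that per-access update (a standalone illustration with hypothetical names, assuming row-major tiles, not the ck_tile implementation):
```
// What scale_tile computes for one [rows x cols] slice of the C tile.
void scale_tile_scalar(float* lds_tile,      // slice of the accumulated C tile
                       const float* scale_m, // one scale per row (M direction)
                       const float* scale_n, // one scale per column (N direction)
                       int rows,
                       int cols)
{
    for(int i = 0; i < rows; ++i)
        for(int j = 0; j < cols; ++j)
            lds_tile[i * cols + j] *= scale_m[i] * scale_n[j]; // element-wise multiply
}
```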

View File

@@ -5,8 +5,7 @@
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"

View File

@@ -1,679 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
struct BQuantGemmProblem
{
CK_TILE_HOST BQuantGemmProblem() = default;
CK_TILE_HOST BQuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_BQ_)
: M(M_),
N(N_),
K(K_),
QK(QK_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_C(stride_C_),
stride_BQ(stride_BQ_)
{
}
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_BQ;
};
struct BQuantGemmHostArgs : public BQuantGemmProblem
{
CK_TILE_HOST BQuantGemmHostArgs() = default;
CK_TILE_HOST BQuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_BQ_)
: BQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_BQ_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
bq_ptr(bq_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* bq_ptr;
void* c_ptr;
index_t k_batch;
};
struct BQuantGemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
const void* bq_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_BQ;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using BQDataType = remove_cvref_t<typename GemmPipeline::BQDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr BQuantGemmKernelArgs
MakeKernelArgs(const BQuantGemmHostArgs& hostArgs)
{
return BQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.bq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_BQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const BQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
}
else
{
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
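// Worked example (sketch): with K = 1024, k_batch = 2 and warp-tile K1 = 32,
// K_t = 64 and KRead = ceil(1024 / 64) * 32 = 512, so block k_id = 0 reads K rows
// [0, 512), block k_id = 1 starts at element offset 512 (for row-major A), and the
// last block computes splitted_k = 1024 - 512 * 1 = 512.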
CK_TILE_HOST static bool IsSupportedArgument(const BQuantGemmKernelArgs& kargs)
{
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
}
return false;
}
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
return false;
}
}
else
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
}
return false;
}
}
return true;
}
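// Typical host flow (sketch): build BQuantGemmHostArgs, call MakeKernelArgs, and
// gate the launch on IsSupportedArgument(kargs); note that this kernel currently
// requires k_batch == 1.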
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
const BQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
const auto& bq_tensor_view = [&]() {
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.N, kargs.QK),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, bq_tensor_view, b_tensor_view, c_tensor_view);
}
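// Note the tuple ordering relied on by the helpers below: I0 = A, I1 = BQ,
// I2 = B, I3 = C.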
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
const auto& bq_pad_view = [&]() {
const auto& bq_tensor_view = views.at(I1);
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return pad_tensor_view(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, bq_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& bq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& bq_block_window = [&]() {
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_n, 0});
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, bq_block_window, b_block_window, c_block_window);
}
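// Each BQ window covers NPerBlock x (KPerBlock / QuantGroupSize) scales; e.g. with
// KPerBlock = 128 and QuantGroupSize = 32 (illustrative values) a block consumes
// 4 scale columns per K iteration.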
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 Pointer to the start of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure holding the split-k offset for this batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const BQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& bq_block_window = gemm_tile_windows.at(I1);
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(BQuantGemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// cast the type-erased argument pointers to the kernel data types
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
RunGemm(a_ptr, b_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
};
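// Instantiation sketch (hypothetical pipeline/partitioner/epilogue types):
//   using Kernel = BQuantGemmKernel<MyTilePartitioner, MyBQuantPipeline, MyEpilogue>;
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       // launch Kernel::GridSize(M, N, /*k_batch=*/1) blocks of Kernel::BlockSize() threads
//   }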
} // namespace ck_tile

View File

@@ -12,117 +12,218 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
namespace ck_tile {
struct AQuantGemmProblem
namespace detail {
// Helper templates for safe type extraction
template <typename T, typename Default>
struct get_aq_layout_or
{
CK_TILE_HOST AQuantGemmProblem() = default;
CK_TILE_HOST AQuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_)
using type = Default;
};
template <typename T, typename Default>
requires requires { typename T::AQLayout; }
struct get_aq_layout_or<T, Default>
{
using type = typename T::AQLayout;
};
template <typename T, typename Default>
struct get_bq_layout_or
{
using type = Default;
};
template <typename T, typename Default>
requires requires { typename T::BQLayout; }
struct get_bq_layout_or<T, Default>
{
using type = typename T::BQLayout;
};
template <typename T, typename Default>
struct get_aq_data_type_or
{
using type = Default;
};
template <typename T, typename Default>
requires requires { typename T::AQDataType; }
struct get_aq_data_type_or<T, Default>
{
using type = typename T::AQDataType;
};
template <typename T, typename Default>
struct get_bq_data_type_or
{
using type = Default;
};
template <typename T, typename Default>
requires requires { typename T::BQDataType; }
struct get_bq_data_type_or<T, Default>
{
using type = typename T::BQDataType;
};
template <typename T, typename Default>
struct get_preshuffle_or
{
using type = Default;
};
template <typename T, typename Default>
requires requires { typename T::PreshuffleQuant; }
struct get_preshuffle_or<T, Default>
{
using type = typename T::PreshuffleQuant;
};
} // namespace detail
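// Usage sketch for the detectors above (hypothetical pipeline type): a pipeline
// that only defines BQLayout yields
//   get_bq_layout_or<Pipe, ColumnMajor>::type == Pipe::BQLayout, and
//   get_aq_layout_or<Pipe, RowMajor>::type   == RowMajor (the fallback),
// so QuantGemmKernel compiles against pipelines that expose only the quant side
// they actually use.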
struct QuantGemmProblem
{
CK_TILE_HOST QuantGemmProblem() = default;
CK_TILE_HOST QuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_,
index_t stride_BQ_)
: M(M_),
N(N_),
K(K_),
QK(QK_),
QK_A(QK_A_),
QK_B(QK_B_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_C(stride_C_),
stride_AQ(stride_AQ_)
stride_AQ(stride_AQ_),
stride_BQ(stride_BQ_)
{
}
index_t M;
index_t N;
index_t K;
index_t QK;
index_t QK_A;
index_t QK_B;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_AQ;
index_t stride_BQ;
};
struct AQuantGemmHostArgs : public AQuantGemmProblem
struct QuantGemmHostArgs : public QuantGemmProblem
{
CK_TILE_HOST AQuantGemmHostArgs() = default;
CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* aq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_)
: AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
CK_TILE_HOST QuantGemmHostArgs() = default;
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* aq_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_,
index_t stride_BQ_)
: QuantGemmProblem(
M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
bq_ptr(bq_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
void* c_ptr;
index_t k_batch;
const void* a_ptr = nullptr;
const void* b_ptr = nullptr;
const void* aq_ptr = nullptr;
const void* bq_ptr = nullptr;
void* c_ptr = nullptr;
index_t k_batch = 0;
};
struct AQuantGemmKernelArgs
struct QuantGemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
const void* bq_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t QK;
index_t QK_A;
index_t QK_B;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_AQ;
index_t stride_BQ;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
QuantType QuantType_>
struct QuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using AQLayout = remove_cvref_t<
typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
using BQLayout = remove_cvref_t<
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
static constexpr bool PreshuffleQuant = remove_cvref_t<
typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
using AQDataType =
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
using BQDataType =
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
static constexpr auto I0 = number<0>(); // A Tensor
static constexpr auto I1 = number<1>(); // AQ Tensor
static constexpr auto I2 = number<2>(); // B Tensor
static constexpr auto I3 = number<3>(); // BQ Tensor
static constexpr auto I4 = number<4>(); // C Tensor
static constexpr auto kQuantType = QuantType_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
@@ -133,22 +234,25 @@ struct AQuantGemmKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
CK_TILE_HOST static constexpr QuantGemmKernelArgs
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
{
return AQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.k_batch};
return QuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.bq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK_A,
hostArgs.QK_B,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.stride_BQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -158,7 +262,7 @@ struct AQuantGemmKernel
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
@@ -198,7 +302,7 @@ struct AQuantGemmKernel
index_t splitted_k;
};
CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
if(kargs.k_batch != 1)
{
@@ -209,14 +313,31 @@ struct AQuantGemmKernel
return false;
}
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
}
return false;
}
}
// NOTE: no kernel currently uses BQuant through this unified path:
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
}
return false;
}
return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -350,8 +471,9 @@ struct AQuantGemmKernel
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
const AQuantGemmKernelArgs& kargs,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
@@ -380,71 +502,85 @@ struct AQuantGemmKernel
return ck_tile::integer_least_multiple(length, alignment) - length;
};
const auto& make_preshuffled_aq_tensor_view = [&]() {
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
};
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
return make_preshuffled_aq_tensor_view();
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, 0), // broadcasting over n
number<1>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
const auto& b_tensor_view = [&]() {
@@ -510,6 +646,32 @@ struct AQuantGemmKernel
}
}();
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(0, 1), // broadcasting over m
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.N, kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
@@ -532,7 +694,8 @@ struct AQuantGemmKernel
}
}();
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
return make_tuple(
a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
}
template <typename TensorView>
@@ -556,6 +719,7 @@ struct AQuantGemmKernel
}
}();
// no padding
const auto& aq_pad_view = [&]() { return views.at(I1); }();
const auto& b_pad_view = [&]() {
@@ -576,9 +740,12 @@ struct AQuantGemmKernel
}
}();
// no padding
const auto& bq_pad_view = [&]() { return views.at(I3); }();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
const auto& c_tensor_view = views.at(I4);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
@@ -595,7 +762,7 @@ struct AQuantGemmKernel
}
}();
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
}
template <typename PadView>
@@ -605,7 +772,8 @@ struct AQuantGemmKernel
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -625,14 +793,13 @@ struct AQuantGemmKernel
}();
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
if constexpr(PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -642,13 +809,27 @@ struct AQuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return nullptr; // TODO: use some other "empty" type?
}
}();
const auto& b_block_window = [&]() {
@@ -668,12 +849,36 @@ struct AQuantGemmKernel
}
}();
const auto& bq_block_window = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(bq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_n, 0});
}
else
{
return nullptr; // TODO: use some other "empty" type here
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window);
return make_tuple(
a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
}
/**
@@ -695,16 +900,17 @@ struct AQuantGemmKernel
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const AQuantGemmKernelArgs& kargs,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -713,22 +919,51 @@ struct AQuantGemmKernel
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& aq_block_window = gemm_tile_windows.at(I1);
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
const auto& aq_block_window = gemm_tile_windows.at(I1);
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& bq_block_window = gemm_tile_windows.at(I3);
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
}
}();
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
auto& c_block_window = gemm_tile_windows.at(I4);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(I1);
const auto& bq_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
}
CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
@@ -740,13 +975,15 @@ struct AQuantGemmKernel
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
RunGemm(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
};

View File

@@ -14,6 +14,7 @@ namespace ck_tile {
template <typename ADataType_,
typename AQDataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
@@ -23,12 +24,12 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
@@ -44,6 +45,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using typename Base::CDataType;
using typename Base::ComputeDataType;
using AQDataType = remove_cvref_t<AQDataType_>;
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
@@ -63,6 +65,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using Base::VectorLoadSize;
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
@@ -75,7 +78,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_aquant_problem",
return concat('_', "gemm_quant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
@@ -94,6 +97,13 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return kPadK ? 1 : GetAlignmentAQ();
}();
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
{
return VectorLoadSize / sizeof(BQDataType);
}
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
};
template <typename ADataType_,
@@ -108,18 +118,19 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
AQDataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AQDataType_,
BDataType_,
void, // no BQDataType for AQuant
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
template <typename ADataType_,
typename BDataType_,
@@ -132,96 +143,42 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Traits = typename Base::Traits;
using typename Base::ADataType;
using typename Base::BDataType;
using typename Base::CDataType;
using typename Base::ComputeDataType;
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
using typename Base::ALayout;
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = Traits::TransposeC;
using Base::kBlockSize;
using Base::kPadK;
using Base::kPadM;
using Base::kPadN;
using Base::DoubleSmemBuffer;
using Base::VectorLoadSize;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_bquant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"QuantGroupSize",
kQuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
{
return VectorLoadSize / sizeof(BQDataType);
}
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
};
using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
void, // no AQDataType for BQuant
BDataType_,
BQDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
false, // no TransposeC
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
template <typename ADataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename AccDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = ADataType_,
bool TransposeC_ = false,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
BDataType_,
BQDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
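// Summary of how the aliases specialize the shared base: AQuant passes void for
// BQDataType, BQuant passes void for AQDataType, and RowColQuant reuses
// AccDataType for both scale types with a quant group size of 1 (no grouping).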
} // namespace ck_tile

View File

@@ -4,9 +4,17 @@
#pragma once
#include "ck_tile/core.hpp"
#include <cstdint>
namespace ck_tile {
enum struct QuantType : std::uint16_t
{
AQuantGrouped = 0,
BQuantGrouped = 1,
RowColQuant = 2
};
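// AQuantGrouped / BQuantGrouped apply group-wise scales to A or B inside the GEMM
// pipeline, while RowColQuant applies per-row and per-column scales to the
// accumulator in the epilogue (see QuantGemmKernel::RunGemm for the dispatch).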
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
@@ -14,19 +22,24 @@ template <bool kPadM_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename AQLayout_ = ALayout_>
struct TileGemmAQuantTraits
QuantType QuantType_,
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using AQLayout = AQLayout_;
using BQLayout = BQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
@@ -35,31 +48,4 @@ struct TileGemmAQuantTraits
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
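// Example instantiation (sketch): a padding-free row/col-quant configuration could
// look like
//   using Traits = TileGemmQuantTraits<false, false, false, /*PreshuffleQuant=*/false,
//                                      tensor_layout::gemm::RowMajor,
//                                      tensor_layout::gemm::ColumnMajor,
//                                      tensor_layout::gemm::RowMajor,
//                                      QuantType::RowColQuant>;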
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename BQLayout_ = BLayout_>
struct TileGemmBQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using BQLayout = BQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
} // namespace ck_tile

View File

@@ -23,4 +23,5 @@ add_subdirectory(add_rmsnorm2d_rdquant)
add_subdirectory(gemm_block_scale)
add_subdirectory(utility)
add_subdirectory(reduce)
add_subdirectory(epilogue)
add_subdirectory(atomic_add_op)

View File

@@ -0,0 +1 @@
add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp)

View File

@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_cshuffle_epilogue_util.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
using namespace ck_tile;
class CShuffleEpilogueTest : public ::testing::Test
{
protected:
void SetUp() override {}
};
TEST_F(CShuffleEpilogueTest, BasicHalfTest)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
}
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <vector>
#include <hip/hip_runtime.h>
namespace ck_tile {
// Simple test kernel to invoke the CShuffleEpilogue
template <typename Problem, index_t M, index_t N, bool UseScale>
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
float* m_scale,
float* n_scale)
{
using Epilogue = CShuffleEpilogue<Problem>;
static_assert(Problem::kMPerBlock <= M && Problem::kNPerBlock <= N,
"Block size must fit in tensor dimensions");
// Allocate shared memory for epilogue
__shared__ char smem[Epilogue::GetSmemSize()];
// Create accumulator tile
constexpr auto lds_distribution_encode =
make_static_tile_distribution(Epilogue::MakeLdsDistributionEncode());
auto acc_tile =
make_static_distributed_tensor<typename Epilogue::AccDataType>(lds_distribution_encode);
// Seed the first accumulator element; the verification below only checks this value
auto& acc_buffer = acc_tile.get_thread_buffer();
acc_buffer[0] = 2.0F;
// Create output tensor view
auto output_tensor_view =
make_naive_tensor_view<address_space_enum::global>(output_data,
make_tuple(M, N),
make_tuple(N, 1),
number<Epilogue::GetVectorSizeC()>{},
number<1>{});
// Create output tile window
auto output_tile_window =
make_tile_window(output_tensor_view,
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
{0, 0});
// Create empty D tensors tuple (we're ignoring ds_dram_windows for this test)
auto empty_ds = make_tuple();
// Call the epilogue
if constexpr(UseScale)
{
const auto m_scale_window = make_tile_window(
make_naive_tensor_view<address_space_enum::global>(
m_scale, make_tuple(M, N), make_tuple(1, 0), number<1>{}, number<1>{}),
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
{0, 0});
const auto n_scale_window = make_tile_window(
make_naive_tensor_view<address_space_enum::global>(
n_scale, make_tuple(M, N), make_tuple(0, 1), number<1>{}, number<1>{}),
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
{0, 0});
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
}
else
{
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
}
}
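// The scale windows above view 1-D per-row / per-column scale vectors as M x N
// tensors via zero strides: stride (1, 0) broadcasts m_scale across N and stride
// (0, 1) broadcasts n_scale across M, matching the RowColQuant views in the kernel.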
// Test configuration helper
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename ODataType,
index_t kM,
index_t kN,
index_t MWave,
index_t NWave,
index_t MPerXdl,
index_t NPerXdl,
index_t KPerXdl>
using SimpleCShuffleEpilogueProblem =
CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>, // Empty Ds datatype tuple
AccDataType,
ODataType,
ck_tile::tuple<>, // Empty Ds layout
tensor_layout::gemm::RowMajor, // ELayout
ck_tile::element_wise::PassThrough, // CDElementwise
kM,
kN,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl,
false, // isCTransposed,
memory_operation_enum::set>;
template <typename Problem, index_t M, index_t N>
bool run_cshuffle_epilogue_test(bool use_scale = false)
{
using ODataType = typename Problem::ODataType;
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
constexpr index_t kBlockSize = Problem::kBlockSize;
std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N
<< ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock
<< ", BlockSize=" << kBlockSize << std::endl;
// Allocate host memory
const size_t output_size = M * N;
std::vector<ODataType> host_output(output_size, static_cast<ODataType>(0));
// Allocate device memory
ODataType* device_output;
HIP_CHECK_ERROR(hipMalloc(&device_output, output_size * sizeof(ODataType)));
HIP_CHECK_ERROR(hipMemcpy(
device_output, host_output.data(), output_size * sizeof(ODataType), hipMemcpyHostToDevice));
// Launch kernel
dim3 gridSize(1, 1, 1);
dim3 blockSize(kBlockSize, 1, 1);
if(use_scale)
{
float* m_scale;
float* n_scale;
std::vector<float> h_m_scale(M, 1.0F);
std::vector<float> h_n_scale(N, 1.0F);
h_n_scale[1] = 2.0F; // scale only the second column by 2
HIP_CHECK_ERROR(hipMalloc(&m_scale, M * sizeof(float)));
HIP_CHECK_ERROR(hipMalloc(&n_scale, N * sizeof(float)));
HIP_CHECK_ERROR(
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, true>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else
{
test_cshuffle_epilogue_kernel<Problem, M, N, false>
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
}
// Check for kernel launch errors
HIP_CHECK_ERROR(hipGetLastError());
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Copy results back
HIP_CHECK_ERROR(hipMemcpy(
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
// Basic verification: output[0] should be 2, and output[1] should be 4 when scaling is used
bool has_2 =
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
bool scale_has_4 = true;
if(use_scale)
{
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
type_convert<float>(host_output[1]) < 4.1F;
}
// Cleanup
HIP_CHECK_ERROR(hipFree(device_output));
return has_2 && scale_has_4;
}
} // namespace ck_tile

View File

@@ -240,4 +240,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -26,7 +26,7 @@ template <typename GemmConfig,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -55,13 +55,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
false, // preshuffle
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -114,8 +115,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -185,7 +188,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
@@ -194,7 +197,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;