[CK_TILE] Implement Row/Col quant grouped gemm (#2786)

* Add cshuffle epilogue test

* add the PoC implementation to the epilogue and tests

* refactor cshuffle epilogue

* WIP: adding tensor/tile usage to scale_tile

* fix usage of tile_elementwise_inout

* add gemm_quant_kernel for generalizing gemm quant kernel

* Add quant-specific problem types, add QuantType to Traits

* Add quant_type to quant_kernel template parameters

* Create aq/bq_block_windows and views depending on QuantType

* Use tile windows as inputs in cshuffle epilogue

* Fix some issues in epilogue

* initial new example code for new general gemm quant kernel test

* Fix issues in kernel

* Add verification check for rowcol QuantMode

* use AccDataType instead of AQ in pipeline

* fix aquant preshuffle

* fix formatting

* some cleanup

* remove gemm_aquant_basic.cpp

* remove gemm_aquant_kernel.hpp

* fix tests for the renamed quant kernel

* fix formatting

* clean example files

* fix some merge conflicts

* fix preshufflequant rename issue

* updating

* fix some templates after merging with develop

* fix test preshuffle parameter

* fix formatting

* updating kernels

* change update user

* test username

* update quant_grouped_gemm example

* update example

* Unify bquant kernel to the common quant kernel

* remove bquant kernel also from common header

* fix formatting

* clean up commented code

* update grouped_gemm_quant example

* fix formatting config hpp

* fix merge mistake

* Non-const for movable windows

* fix formatting

* update tileloop pipeline

* Fix grammar in README

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Remove #include<bit> and clean up example

* fix strides

* Add some descriptions for move_windows

* fix tensor print bug

* update quant_grouped_gemm example

* remove useless code

* cleanup code

* clean up code & format code

* fix compile & running bug in grouped_gemm example

---------

Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: liyingli <liyingli@amd.com>
Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

[ROCm/composable_kernel commit: 4eb415829e]
Authored by kyle-256 on 2025-09-09 01:25:57 +08:00, committed by GitHub
parent 8f02933795, commit f1aec610f1
12 changed files with 1225 additions and 13 deletions

View File

@@ -1,2 +1,3 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)

View File

@@ -175,6 +175,8 @@ mkdir build && cd build
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`

View File

@@ -321,6 +321,36 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);

View File

@@ -0,0 +1,136 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::DoubleSmemBuffer,
true>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ave_time;
}
#include "quant_run_grouped_gemm_example.inc"
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
}

View File

@@ -0,0 +1,157 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk = false);
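Note: get_k_warp_tile() above picks the MFMA K extent from the element width and the warp-tile M size. A minimal compile-time sketch of the resulting values, assuming a build with CK_GFX950_SUPPORT defined (without it, the fallback branch yields 16 or 32 regardless of element type):

```cpp
// Sketch only: expected results of get_k_warp_tile() from the header above,
// assuming CK_GFX950_SUPPORT is defined at compile time.
static_assert(get_k_warp_tile<ck_tile::fp8_t, 32>() == 64);   // 8-bit float, 32x32 warp tile
static_assert(get_k_warp_tile<ck_tile::half_t, 32>() == 16);  // 16-bit type, 32x32 warp tile
static_assert(get_k_warp_tile<ck_tile::fp8_t, 16>() == 128);  // 8-bit float, 16x16 warp tile
static_assert(get_k_warp_tile<ck_tile::half_t, 16>() == 32);  // 16-bit type, 16x16 warp tile
```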

View File

@@ -0,0 +1,443 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
{
// Workspace memory allocated to hold the gemm descriptions.
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
assert(args[0].k_batch == 1);
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(stream, group_count, kargs_ptr);
std::string op_name{"Grouped Gemm"};
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
sizeof(BDataType) * args[j].K * args[j].N +
sizeof(CDataType) * args[j].M * args[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_grouped_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && ((static_cast<std::size_t>(group_count) == args.size()) && ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
if(kbatch > 1 && validate && warmup + repeat > 1)
{
std::cout << "WARNING: Data validation enabled with SplitK and more than"
<< "1 warmup/repeat. Disabling validation." << std::endl;
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
ck_tile::index_t AQK, BQK;
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(0);
stride_Bs.push_back(0);
stride_Cs.push_back(0);
stride_AQs.push_back(0);
stride_BQs.push_back(0);
}
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
aq_tensors.reserve(group_count);
bq_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
aq_dev_buf.reserve(group_count);
bq_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
AQK = 1; // Row quantization: AQ has shape [M, 1]. NT (row-major A, column-major B) layouts only
BQK = N; // Column quantization: BQ has shape [1, N]. NT layouts only
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout))));
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
aq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
bq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a,
p_b,
p_c,
p_aq,
p_bq,
kbatch,
M,
N,
K,
AQK,
BQK,
stride_As[i],
stride_Bs[i],
stride_Cs[i],
stride_AQs[i],
stride_BQs[i]});
}
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout>(warmup, repeat, group_count, gemm_descs);
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
if(validate)
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_rowcol_quant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
using AQDataType = typename Types::AccDataType;
using BQDataType = typename Types::AccDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Row{}, Row{}, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Col{}, Col{}, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
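Note: the validation path above relies on ck_tile::reference_gemm_rowcol_quant. As a mental model, here is a host-side sketch of the assumed semantics (one scale per row of A, one scale per column of B, applied to the accumulator; row-major A/B/C used for brevity). This is an illustration, not the actual ck_tile implementation:

```cpp
#include <vector>

// Sketch of the assumed row/col-quant reference:
//   C[m][n] = aq[m] * bq[n] * sum_k A[m][k] * B[k][n]
template <typename A, typename AQ, typename B, typename BQ, typename Acc, typename C>
void rowcol_quant_gemm_sketch(const std::vector<A>& a, const std::vector<AQ>& aq,
                              const std::vector<B>& b, const std::vector<BQ>& bq,
                              std::vector<C>& c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            Acc acc{0};
            for(int k = 0; k < K; ++k)
                acc += static_cast<Acc>(a[m * K + k]) * static_cast<Acc>(b[k * N + n]);
            // Apply the per-row A scale and the per-column B scale once per output element.
            c[m * N + n] = static_cast<C>(acc * static_cast<Acc>(aq[m]) * static_cast<Acc>(bq[n]));
        }
}
```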

View File

@@ -210,7 +210,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
-stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
+stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));

View File

@@ -162,9 +162,15 @@ CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
using Base = typename tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>::Base;
print(static_cast<const Base&>(descriptor));
printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"

View File

@@ -769,12 +769,11 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{

View File

@@ -0,0 +1,433 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/host.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref QuantGroupedGemmKernel "QuantGroupedGemmKernel" when creating
/// the kernel arguments object. It contains all the information required to build proper kernel
/// arguments and launch the kernel on the GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M, N, K sizes and respective strides.
struct QuantGroupedGemmHostArgs
{
CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
const void* aq_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_,
index_t stride_AQ_,
index_t stride_BQ_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
bq_ptr(bq_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
QK_A(QK_A_),
QK_B(QK_B_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_AQ(stride_AQ_),
stride_BQ(stride_BQ_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
const void* bq_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t QK_A;
index_t QK_B;
index_t stride_A;
index_t stride_B;
index_t stride_AQ;
index_t stride_BQ;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
using QuantGroupedGemmKernelArgs = QuantGemmKernelArgs;
struct QuantGemmTransKernelArg
{
QuantGroupedGemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
QuantGemmTransKernelArg() = delete;
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
{
}
};
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
QuantType QuantType_>
struct QuantGroupedGemmKernel
{
/// @brief Inject the QuantGemmKernel base class to support execution of all necessary
/// functions.
using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
/// @brief Specify the layout configurations for A, B, C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
using AQDataType =
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
using BQDataType =
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
static constexpr auto kQuantType = QuantType_;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel =
QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
{
return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
return group_count * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto BlockSize() -> dim3
{
if(is_wave32())
{
return dim3(kBlockSize / 2);
}
else
{
return dim3(kBlockSize);
}
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
int occupancy;
HIP_CHECK_ERROR(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
{
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
grid_size += local_grid_size * it_desc.k_batch;
}
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
-> std::vector<QuantGemmTransKernelArg>
{
std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
index_t grid_size = 0;
gemm_kernel_args_.reserve(group_count);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M;
const index_t N = gemm_descs[i].N;
const index_t K = gemm_descs[i].K;
if(M == 0 || N == 0 || K == 0)
{
continue;
}
const index_t stride_a = gemm_descs[i].stride_A;
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_e = gemm_descs[i].stride_C;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
grid_size += grid_size_grp;
auto karg =
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].e_ptr),
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
gemm_descs[i].k_batch,
M,
N,
K,
gemm_descs[i].QK_A,
gemm_descs[i].QK_B,
stride_a,
stride_b,
stride_e,
gemm_descs[i].stride_AQ,
gemm_descs[i].stride_BQ};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
return gemm_kernel_args_;
}
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
{
for(const auto& karg : kargs)
{
if(!Base::IsSupportedArgument(karg.group_karg))
{
return false;
}
}
return true;
}
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
static_assert(GemmPipeline::DoubleSmemBuffer == false,
"DoubleSmemBuffer needs to be false");
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
RunGemmWithPipelineSelection(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
* and the tail-number. This is needed for the persistent tile-loop when
* we do not have access to the K dimension on the host.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate the k-batch offset.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
}
// For persistent kernels
template <bool U = UsePersistentKernel,
typename = std::enable_if_t<U>,
typename = void> // extra template parameter to avoid redefinition
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const
{
const index_t grid_size = ck_tile::get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
index_t cum_grid_size = 0;
for(index_t group_id = 0; group_id < group_count; ++group_id)
{
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
const auto& k_batch = kargs.k_batch;
const auto block_start = cum_grid_size;
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
while(block_id < cum_grid_size)
{
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)
{
break; // exit the loop if all blocks are processed
}
}
}
}
};
} // namespace ck_tile
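Note: a host-side simulation (plain C++, hypothetical helper, not device code) of the persistent scheduling implemented by operator() above; k_batch is folded into the per-group tile count for simplicity:

```cpp
#include <cstdio>
#include <vector>

// Every physical workgroup starts at its own block id and strides by the total
// grid size through the concatenated (group, tile) space, as in the device loop.
void simulate_persistent_schedule(const std::vector<int>& tiles_per_group, int grid_size)
{
    for(int wg = 0; wg < grid_size; ++wg)
    {
        int block_id      = wg;
        int cum_grid_size = 0;
        for(std::size_t group_id = 0; group_id < tiles_per_group.size(); ++group_id)
        {
            const int block_start = cum_grid_size;
            cum_grid_size += tiles_per_group[group_id];
            while(block_id < cum_grid_size)
            {
                std::printf("wg %d -> group %zu, tile %d\n", wg, group_id, block_id - block_start);
                block_id += grid_size; // advance to this workgroup's next block
            }
        }
    }
}
```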

View File

@@ -23,8 +23,10 @@ template <bool kPadM_,
typename BLayout_,
typename CLayout_,
QuantType QuantType_,
-typename AQLayout_ = ALayout_,
-typename BQLayout_ = BLayout_>
+typename AQLayout_ = ALayout_,
+typename BQLayout_ = BLayout_,
+bool DoubleSmemBuffer_ = false,
+bool UsePersistentKernel_ = false>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
@@ -33,7 +35,8 @@ struct TileGemmQuantTraits
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_;
using BLayout = BLayout_;
@@ -44,6 +47,7 @@ struct TileGemmQuantTraits
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
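Note: both new trailing parameters default to false, so pre-existing TileGemmQuantTraits instantiations continue to compile unchanged. A condensed sketch of the instantiation used by the row/col-quant example above (parameter order assumed from that example; the unnamed fourth bool is passed as false there):

```cpp
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// Sketch: padding disabled, RowColQuant mode, single LDS buffer,
// persistent tile-loop kernel enabled via the last argument.
using Traits = ck_tile::TileGemmQuantTraits<false, false, false, // kPadM/N/K
                                            false,               // fourth bool, as in the example
                                            Row, Col, Row,       // A/B/C layouts
                                            ck_tile::QuantType::RowColQuant,
                                            Row, Col,            // AQ/BQ layouts
                                            false,               // DoubleSmemBuffer
                                            true>;               // UsePersistentKernel
static_assert(Traits::UsePersistentKernel && !Traits::DoubleSmemBuffer);
```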