ck_tile kernel for gemm with groupwise quantized A tensor (#2473)

* ck_tile kernel for gemm with groupwise quantized A or B tensor.

This change introduces new pipelines with Intrawave scheduler and block gemm primitives that loads the scale tensor to registers to perform dequantization post MFMA on C tensor in registers.

Scale tensor data, AQ/BQ is spliced across threads in registers and not stored in LDS.

Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats.

1. fp8, fp8 -> f32
2. bf8, bf8 -> f32
3. i4, fp8 -> f32
4. i4, bf8 -> f32

Group size can go down to as low as K length of underlying WarpGemm primitive.

For Gemm problems with quantized B tensor, this change also introduces preliminary support for flatmm pipeline which loads B tensor directly into registers.

* [Block Scale Gemm] Only run gemm quant examples on __gfx94__

- Only run gemm quant examples on __gfx94__ for usage of
  `v_cvt_pk_fp8_f32`
- Format the code

* [Block Scale Gemm] Remove Bquant Gemm BlockScale

This cleanup is in preparation for future development of bquant. By
isolating Aquant-related code, we can streamline the codebase and make
it easier to add and maintain bquant functionality in subsequent
updates.

* [Block Scale Gemm] Format code with clang-format-12

The latest clang-format (v19) in ROCm 7.0 generate different result than
clang-format-12 which is used in CK CI.

Format code with clang-format-12 for consistency.

* [Block Scale Gemm] Split the k direction loop

- Split the k direction loop in block_universal_gemm_as_quant_bs_cr.hpp
   to make the logic clearer.
- Disable C transposition.

* [Block Scale Gemm] Move block scale gemm example to 38_block_scale_gemm

* [Block Scale Gemm] Update copyright

* test

* Add TailHandler

* Move TileDistributionEncodingPatternAQ

* Refactor

* refactor

* fix bug

* fix bug

* help solve the PR comment

* Format the code

* [Block Scale Gemm] Add unit tests

* [Block Scale Gemm] Add support to 16x16x32 MFMA

- Add support to 16x16x32 MFMA
- Fix a bug when exchange data crossing lanes

---------

Co-authored-by: Vijay Krishnamoorthy <vjkrish@meta.com>
Co-authored-by: Cong MA <congma13@ctr2-alola-ctrl-01.amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>

[ROCm/composable_kernel commit: e62710e461]
This commit is contained in:
Cong Ma
2025-07-23 01:10:16 -06:00
committed by GitHub
parent 23e4294747
commit baf244000e
35 changed files with 4864 additions and 13 deletions

View File

@@ -114,16 +114,16 @@ template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;

View File

@@ -0,0 +1,13 @@
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -0,0 +1,35 @@
# GEMM Matrix Multiplication
This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# The aquant pipeline method on the gemm calculation
make tile_example_gemm_aquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`
## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```

View File

@@ -0,0 +1,226 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
#include "run_gemm_aquant_example.inc"
template <typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return 0;
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }

View File

@@ -0,0 +1,675 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(__gfx950__)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
using ADataType = ck_tile::int8_t;
using BDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using CDataType = int32_t;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename QDataType_ = float>
struct GemmQuantTypeConfig
{
using ADataType = ADataType_;
using QDataType = QDataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,259 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <bit>
#include <random>
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t AQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<AQDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_copy)
add_subdirectory(37_transpose)
add_subdirectory(38_block_scale_gemm)

View File

@@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
return res;
}
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0);
float x_h = ((x_u8 & 0xf0) >> 4);
x_l = x_l > 7 ? x_l - 16 : x_l;
x_h = x_l > 7 ? x_l - 16 : x_l;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};
#elif
fp32x2_t res = {x_l, x_h};
#endif
return res;
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

View File

@@ -8,6 +8,7 @@
#include <iterator>
#include <optional>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <unordered_set>
@@ -92,6 +93,60 @@ struct FillUniformDistribution
}
};
template <>
struct FillUniformDistribution<ck_tile::pk_int4_t>
{
float a_{-8.f}; // same type as primary template so that
// `FillUniformDistribution<Type>{-5.0f, 5.0f}` works for all types
float b_{7.f};
std::optional<uint32_t> seed_{11939};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
if(a_ < -8.0f || b_ > 7.0f)
{
throw std::runtime_error(
"a_ or b_ of FillUniformDistribution<ck_tile::pk_int4_t> is out of range.");
}
int min_value = static_cast<int>(a_);
int max_value = static_cast<int>(b_);
constexpr auto int4_array = std::array<uint8_t, 16>{0x88,
0x99,
0xaa,
0xbb,
0xcc,
0xdd,
0xee,
0xff,
0x00,
0x11,
0x22,
0x33,
0x44,
0x55,
0x66,
0x77};
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_int_distribution<std::int32_t> dis(0, max_value - min_value + 1);
while(first != last)
{
int randomInt = dis(gen);
*first = int4_array[randomInt + (min_value + 8)];
++first;
}
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
namespace impl {
// clang-format off

View File

@@ -11,6 +11,110 @@
namespace ck_tile {
template <typename ADataType,
typename QDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
uint32_t QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
// Apply group dequant scale
if((k + 1) % QuantGroupSize == 0)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,

View File

@@ -110,6 +110,86 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
return res;
}
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t fp8x4_lo, fp8x4_hi;
float tmp_0, tmp_1;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(fp8x4_lo),
[v_dst_hi] "+v"(fp8x4_hi),
[v_src] "+v"(src)
:);
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
}
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
float res;
asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
return res;
}
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
{
float res;
asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
return res;
}
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t bf8x4_lo, bf8x4_hi;
float tmp_0, tmp_1;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(bf8x4_lo),
[v_dst_hi] "+v"(bf8x4_hi),
[v_src] "+v"(src)
:);
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
}
struct PassThroughPack8
{
template <typename Y, typename X>
@@ -126,6 +206,16 @@ struct PassThroughPack8
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
}
constexpr const static bool is_pack8_invocable = true;
};

View File

@@ -69,6 +69,8 @@ struct CShuffleEpilogue
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
@@ -201,7 +203,7 @@ struct CShuffleEpilogue
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmMfmaDispatcher<ADataType,
using WG = WarpGemmMfmaDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,

View File

@@ -29,6 +29,9 @@ struct TileFlatmmShape
static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN);
static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK);
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
CK_TILE_HOST static std::string GetName()
{
// clang-format off

View File

@@ -162,9 +162,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}

View File

@@ -37,6 +37,7 @@ struct WarpGemmAtrributeMfma
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
static constexpr index_t kCMLane = Impl::kCMLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }

View File

@@ -11,9 +11,10 @@ struct WarpGemmImpl
{
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).

View File

@@ -0,0 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp"

View File

@@ -0,0 +1,489 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
template <typename Problem, index_t UnaryOpSize_ = 8>
struct BlockGemmQuantBase
{
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
static constexpr index_t UnaryOpSize = UnaryOpSize_;
template <typename T>
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
{
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_f = ck_tile::bit_cast<float>(scale);
}
else
{
static_assert(false, "AQDataType must be float, fp8_t or bf8_t.");
}
return scale_reg_f;
}
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
// A is block window on shared memory
// AQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
// number of warps along M and N for threadblock's GEMM problem size
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
using I0 = number<0>;
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
// Currently tested combinations (A, AQ, B)
// 1. fp8, fp32, fp8 -> f32
// 2. bf8, fp32, bf8 -> f32
// 3. i4, (fp8/fp32) fp8 -> f32
// 4. i4, (fp8/fp32) bf8 -> f32
static_assert(
(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<
ADataType,
bf8_t>)&&(std::is_same_v<BDataType, fp8_t> ||
std::is_same_v<
BDataType,
bf8_t>)&&(std::is_same_v<AQDataType, float> ||
std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<
AQDataType,
ck_tile::bf8_t>)&&(std::is_same_v<ComputeDataType,
fp8_t> ||
std::is_same_v<ComputeDataType,
bf8_t>)&&std::
is_same_v<CDataType, fp32_t>);
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using AQDataType = remove_cvref_t<typename Traits::AQDataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Base = BlockGemmQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
}
// C += A * B
template <typename CBlockTensor,
typename AQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
AQBlockTensor& aq_block_tensor,
[[maybe_unused]] ASmemBlockWindow& a_block_window,
[[maybe_unused]] BSmemBlockWindow& b_block_window)
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on aquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);
constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;
// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane;
// Lane index to source scale from
uint32_t src_lane_idx = lane_base_offset + row_base +
(__lane_id() / WarpGemm::kN * kTileRows);
// Directly index into thread buffer corresponding to
// desired row coefficient
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// Pull scale data across lanes
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
c_block_tensor
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
});
});
});
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
}
// C += A * B
template <typename CBlockTensor,
typename AQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
AQBlockTensor& aq_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_(c_block_tensor, aq_block_tensor, a_block_window, b_block_window);
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
} // namespace ck_tile

View File

@@ -0,0 +1,679 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
struct AQuantGemmProblem
{
CK_TILE_HOST AQuantGemmProblem() = default;
CK_TILE_HOST AQuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_)
: M(M_),
N(N_),
K(K_),
QK(QK_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_C(stride_C_),
stride_AQ(stride_AQ_)
{
}
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_AQ;
};
struct AQuantGemmHostArgs : public AQuantGemmProblem
{
CK_TILE_HOST AQuantGemmHostArgs() = default;
CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* aq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_)
: AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
void* c_ptr;
index_t k_batch;
};
struct AQuantGemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_AQ;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
{
return AQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
}
else
{
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
{
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
}
return false;
}
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
return false;
}
}
else
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
}
return false;
}
}
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
CDataType* c_ptr,
const AQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
const auto& aq_pad_view = [&]() {
const auto& aq_tensor_view = views.at(I1);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return pad_tensor_view(
aq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const AQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& aq_block_window = gemm_tile_windows.at(I1);
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
static_assert(KPerBlock % QuantGroupSize == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// Create DRAM tile window for AQ
template <typename AQDramBlockWindowTmp>
CK_TILE_DEVICE constexpr auto
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using YPerTile = number<MPerBlock>;
using XPerTile = number<KPerBlockAQ>;
auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
aq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeAQDramTileDistribution<Problem>());
return aq_copy_dram_window;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_group_quant_utils.hpp"
namespace ck_tile {
struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::ATileAccessPattern;
using Base::BTileAccessPattern;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
{
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
return GetAQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
{
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return AQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,476 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem>
struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
else
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
}
};
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t AQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<AQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetVectorSizeAQ()
{
return Policy::template GetVectorSizeAQ<Problem>();
}
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "aquant_pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock),
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t AQ_Buffer_Load_Inst_Num =
MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
<< "AQ vector size: " << GetVectorSizeAQ() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< ", "
<< "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AQDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
std::is_same_v<AQDataType,
remove_cvref_t<typename AQDramBlockWindowTmp::DataType>>,
"A/B/AQ Dram block window should have the same data type as appropriate "
"([A|B|AQ]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_aq_col_major =
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Aq block window has incorrect lengths for defined AqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using AQBlockTile =
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
auto block_gemm = BlockGemm();
ABlockTile a_block_tile;
BBlockTile b_block_tile;
AQBlockTile aq_block_tile[2];
int currIdx = 0;
auto c_block_tile = block_gemm.MakeCBlockTile();
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
block_gemm(
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
currIdx = (currIdx + 1) % 2;
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
block_gemm(
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
else
{
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
block_gemm(
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,121 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include <string>
namespace ck_tile {
template <typename ADataType_,
typename AQDataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Traits = typename Base::Traits;
using typename Base::ADataType;
using typename Base::BDataType;
using typename Base::CDataType;
using typename Base::ComputeDataType;
using AQDataType = remove_cvref_t<AQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
using typename Base::ALayout;
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = false;
using Base::kBlockSize;
using Base::kPadK;
using Base::kPadM;
using Base::kPadN;
using Base::DoubleSmemBuffer;
using Base::VectorLoadSize;
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_aquant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"QuantGroupSize",
kQuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return VectorLoadSize / sizeof(AQDataType);
}
static constexpr index_t VectorSizeAQ = []() {
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return kPadK ? 1 : GetAlignmentAQ();
}();
};
template <typename ADataType_,
typename AQDataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
AQDataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile

View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize()
{
using I1 = number<1>;
constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t BlockSize = Problem::kBlockSize;
// Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Define vector load candidates in descending order of priority
constexpr std::array<index_t, 5> candidates{
PackedSize * 32 / sizeof(DataType),
PackedSize * 16 / sizeof(DataType),
PackedSize * 8 / sizeof(DataType),
PackedSize * 4 / sizeof(DataType),
PackedSize * 2 / sizeof(DataType),
};
for(const auto vec_size : candidates)
{
if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0)
continue;
bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) &&
(elements_per_thread % vec_size == 0) && vec_size != candidates[4];
if(is_valid)
{
return vec_size;
}
}
return PackedSize; // Absolute fallback
}
// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
// threads. Post mfma scales are shuffled across threads in the warp and applied to
// accum registers.
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
static_assert(num_warps == MWarps * NWarps * KWarps);
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t Y0 = 1;
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
static constexpr index_t Y2 = MWarps;
static constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename AQLayout_ = ALayout_>
struct TileGemmAQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using AQLayout = AQLayout_;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
};
} // namespace ck_tile

View File

@@ -17,3 +17,4 @@ add_subdirectory(topk_softmax)
add_subdirectory(add_rmsnorm2d_rdquant)
# add_subdirectory(layernorm2d)
# add_subdirectory(rmsnorm2d)
add_subdirectory(gemm_block_scale)

View File

@@ -0,0 +1,19 @@
set(TEST_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND TEST_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
set(TEST_GEMM_NAME test_tile_gemm_aquant_basic)
set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8)
foreach(QUANT_TYPE ${QUANT_TYPES})
add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp)
target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
endforeach()
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("bf8"); }

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("fp8"); }

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4bf8"); }

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32bf8"); }

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32fp8"); }

View File

@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4fp8"); }

View File

@@ -0,0 +1,681 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(__gfx950__)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
class ArgumentsNotSupportedException : public std::logic_error
{
public:
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
using ADataType = ck_tile::int8_t;
using BDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using CDataType = int32_t;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename QDataType_ = float>
struct GemmQuantTypeConfig
{
using ADataType = ADataType_;
using QDataType = QDataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = float;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,577 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <bit>
#include <random>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout>
bool run_gemm_test_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t AQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<AQDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return true;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}
template <typename TypeConfig, uint32_t QuantGroupSize>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return true;
}
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_test_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int run_gemm_combinations(std::string const& data_type)
{
// Define possible values for each parameter
std::vector<std::vector<std::string>> mnk_values = {{
"1",
"2048",
"5120",
},
{
"2",
"2048",
"5120",
},
{
"16",
"2048",
"5120",
},
{
"17",
"2048",
"5120",
},
{
"2047",
"5120",
"1024",
},
{
"2048",
"5120",
"1024",
}};
std::vector<std::string> prec_values = {data_type};
// We'll store all our arguments as strings first
std::vector<std::string> arg_strings = {"test_tile_gemm_aquant_basic",
"", // m placeholder
"", // n placeholder
"", // k placeholder
"", // prec placeholder
"-init=0",
"-v=1",
"-warmup=0",
"-repeat=1"};
// Create an array of const char pointers for argv
constexpr size_t ARG_COUNT = 9;
constexpr size_t ARG_MAX_LEN = 64;
char args[ARG_COUNT][ARG_MAX_LEN];
char* argv[ARG_COUNT];
// Run all combinations
bool is_success = true;
for(const auto& mnk : mnk_values)
{
arg_strings[1] = "-m=" + mnk[0];
arg_strings[2] = "-n=" + mnk[1];
arg_strings[3] = "-k=" + mnk[2];
for(const auto& prec : prec_values)
{
arg_strings[4] = "-prec=" + prec;
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
// Call the function with the current configuration
try
{
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}