This commit is contained in:
mtgu0705
2025-08-05 04:07:25 -05:00
parent 32c4d5bc15
commit 9c40290dae
13 changed files with 1782 additions and 165 deletions

View File

@@ -0,0 +1,13 @@
# Compile options shared by the ck_tile MX GEMM example target.
set(EXAMPLE_GEMM_MX_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
    # Propagate the project-level OCP FP8 switch as a ck_tile preprocessor define.
    list(APPEND EXAMPLE_GEMM_MX_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# Forwarded to LLVM via -mllvm: sets enable-noalias-to-md-conversion to 0.
list(APPEND EXAMPLE_GEMM_MX_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)

# The example is only configured when GPU_TARGETS mentions a gfx94* or gfx95* device.
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
    add_executable(tile_example_gemm_mx_basic EXCLUDE_FROM_ALL gemm_mx_basic.cpp)
    target_compile_options(tile_example_gemm_mx_basic PRIVATE ${EXAMPLE_GEMM_MX_COMPILE_OPTIONS})
else()
    message(DEBUG "Skipping ck_tile gemm_mx example for current target")
endif()

View File

@@ -0,0 +1,35 @@
# GEMM Matrix Multiplication
This folder contains an example of Block Scale GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Build the MX GEMM example
make tile_example_gemm_mx_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_mx_basic`
## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```

View File

@@ -13,16 +13,19 @@
#include "gemm_utils.hpp"
template <typename ADataType,
typename AQDataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename AScaleBLayout,
typename BLayout,
typename BScaleCLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
uint32_t BlockScaleSize>
float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -32,17 +35,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 128;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -51,8 +54,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmMXTraits<kPadM,
kPadN,
kPadK,
ALayout,
AScaleLayout,
BLayout,
BScaleLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -61,7 +70,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using BaseGemmPipeline = ck_tile::BaseGemmMXPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -74,39 +83,39 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
ck_tile::GemmMXPipelineProblem<ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
BlockScaleSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::GemmMXPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
y transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::GemmMXKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -170,7 +179,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
int run_gemm_example(int argc, char* argv[])
int run_gemm_mx_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
@@ -182,10 +191,11 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::half_t,
ck_tile::e8m0_bexp_t>{});
using TypeConfig = decltype(GemmMXTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::e8m0_bexp_t,
int32_t,
ck_tile::half_t>{});
return run_gemm_example_prec_type<TypeConfig, 32>(a_layout, b_layout, argc, argv);
}
else
@@ -194,4 +204,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_mx_example(argc, argv); }

View File

@@ -375,86 +375,51 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
// template <typename ADataType,
// typename BDataType = ADataType,
// typename ScaleDatatype = ADataType,
// typename CDataType = ADataType>
// struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
using ADataType = ck_tile::int8_t;
using BDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using CDataType = int32_t;
};
// template <>
// struct GemmTypeConfig<ck_tile::half_t>
// {
// using ADataType = ck_tile::half_t;
// using BDataType = ck_tile::half_t;
// using AccDataType = float;
// using CDataType = ck_tile::half_t;
// // ToDo: Add more bias config to support different categories of GEMM MX.
// };
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename QDataType_ = float>
typename BDataType_ = ADataType_,
typename ScaleDataType_ = ADataType_,
typename ScalePackDataType_ = ScaleDataType_,
typename CDataType_ = ADataType_>
struct GemmMXTypeConfig
{
using ADataType = ADataType_;
using QDataType = QDataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
using ADataType = ADataType_;
using BDataType = BDataType_;
using ScaleDataType = ScaleDataType_;
using ScalePackDataType = ScalePackDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
// microscaling gemm
template <>
struct GemmMXTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t, ck_tile::e8m0_bexp_t>
struct GemmMXTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::e8m0_bexp_t,
int32_t,
ck_tile::half_t>
{
using ADataType = ck_tile::pk_fp4_t;
using BDataType = ck_tile::pk_fp4_t;
using QDataType = ck_tile::e8m0_bexp_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
using ADataType = ck_tile::pk_fp4_t;
using BDataType = ck_tile::pk_fp4_t;
using ScaleDataType = ck_tile::e8m0_bexp_t;
using ScalePackDataType = int32_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
}
template <typename T>

View File

@@ -32,7 +32,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
@@ -41,46 +40,54 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::AQuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
ck_tile::GemmMXKernelArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.a_scale_ptr_ = a_m_k_scale_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.b_scale_ptr_ = b_k_n_scale_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_scale_A = stride_scale_A; // stride for A scale
args.stride_B = stride_B;
args.stride_scale_B = stride_scale_B; // stride for B scale
args.stride_C = stride_C;
float ave_time = gemm_calc_aquant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
float ave_time = gemm_mx_calc<ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
BlockScaleSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / BlockScaleSize;
std::size_t num_byte =
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
sizeof(BDataType) * K * N / ck_tile::numeric_traits<BDataType>::PackedSize +
sizeof(ck_tile::e8m0_bexp_t) * M * K / BlockScaleSize +
sizeof(ck_tile::e8m0_bexp_t) * K * N / BlockScaleSize + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideA =" << stride_A << " StrideScaleA =" << stride_scale_A
<< " StrideB =" << stride_B << " StrideScaleB =" << stride_scale_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " A_Scale_Type = " << DataTypeTraits<AScaleDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " B_Scale_Type = " << DataTypeTraits<BScaleDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -107,12 +114,13 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using AScaleDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using BScaleDataType = typename TypeConfig::QDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
using ADataType = typename TypeConfig::ADataType;
using AScaleDataType = typename TypeConfig::ScaleDataType;
using BDataType = typename TypeConfig::BDataType;
using BScaleDataType = typename TypeConfig::ScaleDataType;
using XPackedDataType = typename TypeConfig::ScalePackDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -177,23 +185,24 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.0f, 1.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.0f, 1.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<ASCaleDataType>{-1.0f, 1.0f, fill_seed(gen)}(a_m_k_scale);
ck_tile::FillUniformDistribution<BScaleDataType>{-1.0f, 1.0f, fill_seed(gen)}(b_k_n_scale);
ck_tile::FillUniformDistribution<ADataType>{-1.0f, 1.0f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.0f, 1.0f}(b_k_n);
ck_tile::FillUniformDistribution<ASCaleDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck_tile::FillUniformDistribution<BScaleDataType>{-1.0f, 1.0f}(b_k_n_scale);
}
else if(init_method == 1)
{
ck_tile::FillConstant<ADataType>{ck_tile::type_convert<ADataType>(ck_tile::float2_t(0.5f))}(
a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
ck_tile::FillConstant<ADataType>{1.0f, 1.0f}(a_m_k);
ck_tile::FillConstant<BDataType>{1.0f, 1.0f}(b_k_n);
ck_tile::FillConstant<ASCaleDataType>{1.0f, 1.0f}(a_m_k_scale);
ck_tile::FillConstant<BScaleDataType>{1.0f, 1.0f}(b_k_n_scale);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
a_m_k_scale.SetZero();
b_k_n_scale.SetZero();
}
// Shuffle A, B scale tensors
@@ -216,9 +225,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
AScaleDataType,
XPackedDataType,
BDataType,
BScaleDataType,
XPackedDataType,
AccDataType,
CDataType,
ALayout,

View File

@@ -0,0 +1,16 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// #include "ck_tile/ops/gemm_mx/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/tile_gemm_mx_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,686 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
/// Problem sizes and strides of an MX GEMM; base class of GemmMXHostArgs.
struct GemmMXProblem
{
    /// Defaulted: no member has an initializer, so all fields are left unset.
    CK_TILE_HOST GemmMXProblem() = default;

    /// Stores every argument in the member of the same name.
    /// Note the order: the A/B/C strides come first, the two scale strides last.
    CK_TILE_HOST GemmMXProblem(index_t M_,
                               index_t N_,
                               index_t K_,
                               index_t stride_A_,
                               index_t stride_B_,
                               index_t stride_C_,
                               index_t stride_scale_A_,
                               index_t stride_scale_B_)
        : M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_C(stride_C_),
          stride_scale_A(stride_scale_A_),
          stride_scale_B(stride_scale_B_)
    {
    }

    // GEMM extents.
    index_t M;
    index_t N;
    index_t K;
    // Strides of the A, B and C tensors.
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
    // Strides of the A-scale and B-scale tensors.
    index_t stride_scale_A;
    index_t stride_scale_B;
};
struct GemmMXHostArgs : public GemmMXProblem
{
CK_TILE_HOST GemmMXHostArgs() = default;
CK_TILE_HOST GemmMXHostArgs(const void* a_ptr_,
const void* a_scale_ptr_,
const void* b_ptr_,
const void* b_scale_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_scale_A_,
index_t stride_B_,
index_t stride_scale_B_,
index_t stride_C_)
: GemmMXProblem(
M_, N_, K_, stride_A_, stride_B_, stride_C_, stride_scale_A_, stride_scale_B_),
a_ptr(a_ptr_),
a_scale_ptr_(a_scale_ptr_),
b_ptr(b_ptr_),
b_scale_ptr_(b_scale_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* a_scale_ptr_;
const void* b_ptr;
const void* b_scale_ptr_;
void* c_ptr;
index_t k_batch;
};
// Plain aggregate of MX GEMM arguments: operand, scale and output pointers,
// problem sizes, strides and the k_batch count.
// NOTE(review): the scale pointers are spelled a_scale_ptr / b_scale_ptr here but
// a_scale_ptr_ / b_scale_ptr_ in GemmMXHostArgs — confirm which spelling is intended.
struct GemmMXKernelArgs
{
const void* a_ptr;
const void* a_scale_ptr;
const void* b_ptr;
const void* b_scale_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_scale_A;
index_t stride_B;
index_t stride_scale_B;
index_t stride_C;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmMXKernel
{
// Template arguments with cv/ref qualifiers removed.
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
// Layout tags taken from the pipeline.
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AScaleLayout = remove_cvref_t<typename GemmPipeline::AScaleLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using BScaleLayout = remove_cvref_t<typename GemmPipeline::BScaleLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
// Value handed to dim3 by BlockSize().
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
// Element types: A/B and their scales come from the pipeline, C from the epilogue's ODataType.
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AScaleDataType = remove_cvref_t<typename GemmPipeline::AScaleDataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using BScaleDataType = remove_cvref_t<typename GemmPipeline::BScaleDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
// Compile-time indices, used with views.at(...) by the view/window helpers.
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
// Returns the kernel name built by concat() from '_', "gemm", the A/B precision
// string and the pipeline's own name ('_' is presumably the separator — confirm
// against concat()).
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
// Launch grid: x = TilePartitioner::GridSize(M, N), y = 1, z = KBatch.
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
// Launch block: dim3 built from KernelBlockSize.
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
// Copies the host-side fields into an AQuantGemmKernelArgs, in the order listed in
// the braced initializer below.
// NOTE(review): this still consumes AQuantGemmHostArgs and produces
// AQuantGemmKernelArgs (aq_ptr, QK, stride_AQ), while this header declares
// GemmMXHostArgs / GemmMXKernelArgs with separate A-scale and B-scale pointers and
// strides. Looks like an unported copy of the AQuant kernel — confirm and switch
// this (and the methods that take the kernel args) to the GemmMX types.
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
{
return AQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.k_batch};
}
// Returns the larger of the pipeline's and the epilogue's GetSmemSize() values.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
// Split-K bookkeeping for one k_id: the offsets into A and B at which that split
// starts and the K extent (splitted_k) it covers.
// NOTE(review): takes AQuantGemmKernelArgs rather than the GemmMXKernelArgs declared
// in this header — confirm.
struct SplitKBatchOffset
{
// k_id defaults to blockIdx.z; GridSize() places KBatch in the grid's z dimension.
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
// K1 is the third WarpTile entry (presumably the warp-tile K size — confirm).
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
// K_t = k_batch * K1; KRead = ceil(K / K_t) * K1, i.e. a per-split K extent that is
// a multiple of K1. Values go through __builtin_amdgcn_readfirstlane (presumably so
// they are treated as wave-uniform — confirm).
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
// A offset: k_id * KRead for RowMajor A, additionally scaled by stride_A for ColumnMajor A.
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
}
// B offset: mirrored — scaled by stride_B for RowMajor B, plain k_id * KRead for ColumnMajor B.
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
// Every split but the last covers KRead; the last one takes whatever remains of K.
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
}
else
{
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
/**
 * @brief Host-side check of whether this kernel can run the given arguments.
 *
 * Returns false when any of the following holds:
 *  - k_batch is not 1;
 *  - QK is not a multiple of GemmPipeline::GetVectorSizeAQ();
 *  - for A, B and C (depending on each layout) the relevant dimension is not a
 *    multiple of the block tile while the matching kPad flag is off, or is not a
 *    multiple of the tensor's vector size.
 * A diagnostic is emitted through CK_TILE_ERROR only when the CK_TILE_LOGGING
 * environment switch is enabled.
 *
 * NOTE(review): takes AQuantGemmKernelArgs and checks QK / AQLayout; this header
 * declares GemmMXKernelArgs with A/B scale fields instead — confirm and port.
 *
 * @param kargs kernel arguments to validate
 * @return true when every check passes, false otherwise
 */
CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
{
    // Split-K is rejected outright.
    if(kargs.k_batch != 1)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
        }
        return false;
    }

    static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
    if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            // Distinct from the A-tensor message below so the failing check is identifiable.
            CK_TILE_ERROR("QK is not a multiple of vector load size for AQ tensor!");
        }
        return false;
    }

    // A tensor: K is the contiguous dimension for RowMajor, M otherwise.
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              "without padding!");
            }
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
            }
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Can't support M that is not a multiple of MPerBlock without padding!");
            }
            return false;
        }
        if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
            }
            return false;
        }
    }

    // B tensor: N is the contiguous dimension for RowMajor, K otherwise.
    if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Can't support N that is not a multiple of NPerBlock without padding!");
            }
            return false;
        }
        if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
            }
            return false;
        }
    }
    else
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              "without padding!");
            }
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
            }
            return false;
        }
    }

    // C tensor: N is the contiguous dimension for RowMajor, M otherwise.
    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Can't support N that is not a multiple of NPerBlock without padding!");
            }
            return false;
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
            }
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR(
                    "Can't support M that is not a multiple of MPerBlock without padding!");
            }
            return false;
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
            }
            return false;
        }
    }
    return true;
}
// Wraps the raw pointers in address_space_enum::global tensor views and returns
// them as a tuple in the order (A, AQ, B, C). DstInMemOp is forwarded to the C view.
// NOTE(review): the signature and the AQ view still use AQDataType / AQLayout /
// kargs.QK / kargs.stride_AQ from the AQuant kernel; this struct only declares
// AScale*/BScale* aliases — confirm and port to separate A-scale and B-scale views.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
CDataType* c_ptr,
const AQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
// A: (M, splitted_k) for RowMajor, (splitted_k, M) otherwise; strides (stride_A, 1) in both cases.
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
// AQ: (M, QK) with strides (stride_AQ, 1); only RowMajor is accepted.
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}();
// B: with PermuteB a (K0, N, K1) descriptor is built and K0/K1 are merged back into
// one dimension; without it a plain 2-D view is used ((splitted_k, N) for RowMajor,
// (N, splitted_k) otherwise).
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
// Same transform as the RowMajor branch, but the output dimensions are swapped
// (merged K -> dim 1, N -> dim 0).
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}();
// TODO: enable vector write for C in ColMajor
// C: always (M, N); strides (stride_C, 1) for RowMajor, (1, stride_C) otherwise.
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
}
// Applies pad_tensor_view to each of the four views, using the block-tile sizes as
// the second argument, and returns the results in the same (A, AQ, B, C) order.
// The sequence<...> argument carries the pipeline's kPad flags (presumably one flag
// per dimension selecting whether it is padded — confirm against pad_tensor_view).
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
const auto& aq_pad_view = [&]() {
const auto& aq_tensor_view = views.at(I1);
// NOTE(review): this asserts on ALayout, while the AQ code in the sibling helpers
// asserts on AQLayout; as written a ColumnMajor A cannot compile even though
// a_pad_view above handles it — confirm which layout was meant.
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return pad_tensor_view(
aq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
}
// Builds one block-sized tile window per padded view and returns them as a tuple
// in the order (A, AQ, B, C).
//
// views : tuple of padded views, indexed I0..I3 as (A, AQ, B, C)
// i_m   : window origin along M
// i_n   : window origin along N
//
// The origin along K is always 0. The AQ window spans
// KPerBlock / GemmPipeline::QuantGroupSize entries along its second dimension.
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
    const auto& padded_a  = views.at(I0);
    const auto& padded_aq = views.at(I1);
    const auto& padded_b  = views.at(I2);
    const auto& padded_c  = views.at(I3);

    // A window: (MPerBlock, KPerBlock) for RowMajor A, transposed otherwise.
    const auto window_a = [&]() {
        if constexpr(!std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return make_tile_window(padded_a,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::MPerBlock>{}),
                                    {0, i_m});
        }
        else
        {
            return make_tile_window(padded_a,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_m, 0});
        }
    }();

    // AQ window: only a RowMajor AQ layout is accepted.
    const auto window_aq = [&]() {
        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
        return make_tile_window(
            padded_aq,
            make_tuple(number<TilePartitioner::MPerBlock>{},
                       number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
            {i_m, 0});
    }();

    // B window: (NPerBlock, KPerBlock) for ColumnMajor B, transposed otherwise.
    const auto window_b = [&]() {
        if constexpr(!std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            return make_tile_window(padded_b,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::NPerBlock>{}),
                                    {0, i_n});
        }
        else
        {
            return make_tile_window(padded_b,
                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_n, 0});
        }
    }();

    // C window: always (MPerBlock, NPerBlock) anchored at (i_m, i_n).
    auto window_c = make_tile_window(
        padded_c,
        make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
        {i_m, i_n});

    return make_tuple(window_a, window_aq, window_b, window_c);
}
/**
 * @brief Computes one GEMM output tile cooperatively with the whole workgroup.
 *
 * Builds tensor views, padded views and tile windows from the raw pointers, runs GemmPipeline
 * on the A, B and AQ windows, then passes the resulting tile and the C window to
 * EpiloguePipeline. The same shared-memory pointer is handed to both pipelines.
 *
 * @param a_ptr input A pointer
 * @param b_ptr input B pointer
 * @param aq_ptr input AQ pointer
 * @param c_ptr output C pointer
 * @param smem_ptr_0 The start memory pointer of the shared memory block.
 * @param kargs GEMM kernel arguments
 * @param splitk_batch_offset Utility structure used to calculate k batch; its splitted_k
 *                            member determines the loop count.
 * @param block_idx_m Forwarded to MakeGemmTileWindows as the M origin of the tile; the kernel
 *                    entry point passes the tile index multiplied by MPerBlock.
 * @param block_idx_n Forwarded to MakeGemmTileWindows as the N origin of the tile; the kernel
 *                    entry point passes the tile index multiplied by NPerBlock.
 *
 * @tparam DstInMemOp Destination memory operation (default: set).
 */
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                   const BDataType* b_ptr,
                                   const AQDataType* aq_ptr,
                                   CDataType* c_ptr,
                                   void* smem_ptr_0,
                                   const AQuantGemmKernelArgs& kargs,
                                   const SplitKBatchOffset& splitk_batch_offset,
                                   const index_t block_idx_m,
                                   const index_t block_idx_n)
{
    // Raw pointers -> tensor views -> padded views -> per-tile windows.
    const auto& tensor_views = MakeGemmTensorViews<DstInMemOp>(
        a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
    const auto& padded_views = MakeGemmPadViews(tensor_views);
    auto tile_windows        = MakeGemmTileWindows(padded_views, block_idx_m, block_idx_n);

    // Loop count derived from the K extent assigned to this split.
    const index_t k_loop_count = __builtin_amdgcn_readfirstlane(
        TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

    // Main loop: note the pipeline takes (A, B, AQ), while the tuple is ordered (A, AQ, B, C).
    const auto& a_window  = tile_windows.at(I0);
    const auto& aq_window = tile_windows.at(I1);
    const auto& b_window  = tile_windows.at(I2);
    const auto& acc_tile =
        GemmPipeline{}(a_window, b_window, aq_window, k_loop_count, smem_ptr_0);

    // Epilogue: the C window is passed both as first and third argument.
    auto& c_window = tile_windows.at(I3);
    using CWindowRef = decltype(c_window);
    using AccTileRef = decltype(acc_tile);
    EpiloguePipeline{}.template operator()<CWindowRef, AccTileRef, CWindowRef>(
        c_window, acc_tile, c_window, smem_ptr_0);
}
/**
 * @brief Kernel entry point: maps this workgroup to one output tile and runs RunGemm on it.
 *
 * @param kargs Kernel arguments. The a/b/aq/c pointers are cast to their typed forms, M and N
 *              feed the tile partitioner, and k_batch is asserted to be 1.
 */
CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
{
    // Workgroup id -> (tile_m, tile_n) -> element offsets of the tile origin.
    const auto workgroup_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
    const auto [tile_m, tile_n] =
        TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(workgroup_id);
    const index_t m_origin = __builtin_amdgcn_readfirstlane(tile_m * TilePartitioner::MPerBlock);
    const index_t n_origin = __builtin_amdgcn_readfirstlane(tile_n * TilePartitioner::NPerBlock);

    const SplitKBatchOffset k_split(kargs);

    // Typed views of the type-erased kernel argument pointers.
    const auto* a_typed  = static_cast<const ADataType*>(kargs.a_ptr);
    const auto* b_typed  = static_cast<const BDataType*>(kargs.b_ptr);
    const auto* aq_typed = static_cast<const AQDataType*>(kargs.aq_ptr);
    auto* c_typed        = static_cast<CDataType*>(kargs.c_ptr);

    // Shared-memory scratch sized by GetSmemSize().
    __shared__ char lds_scratch[GetSmemSize()];

    assert(kargs.k_batch == 1);
    RunGemm(a_typed, b_typed, aq_typed, c_typed, lds_scratch, kargs, k_split, m_origin, n_origin);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
/**
 * @brief Pipeline implementation base that adds a DRAM load window for the A-scale tensor.
 *
 * Everything for A and B comes from GemmPipelineAgBgCrImplBase. The scale layout and scale
 * block size are read from the problem's AScaleLayout / kBlockScaleSize members; they are
 * re-exported here under the AQ-style member names the pipeline code uses.
 *
 * @tparam Problem Pipeline problem; must provide AScaleLayout and kBlockScaleSize.
 * @tparam Policy  Pipeline policy; must provide MakeAQDramTileDistribution<Problem>().
 */
template <typename Problem, typename Policy>
struct GemmMXPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
    using Base           = GemmPipelineAgBgCrImplBase<Problem, Policy>;
    using ADataType      = typename Base::ADataType;
    using ALayout        = typename Base::ALayout;
    using BDataType      = typename Base::BDataType;
    using BLayout        = typename Base::BLayout;
    using BlockGemmShape = typename Base::BlockGemmShape;

    // Layout of the A-scale tensor (the MX problem names it AScaleLayout).
    using AQLayout = remove_cvref_t<typename Problem::AScaleLayout>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // Number of K elements sharing one scale value (the MX problem names it kBlockScaleSize).
    static constexpr index_t QuantGroupSize = Problem::kBlockScaleSize;
    // Number of scale values along K in one block tile.
    static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;

    static_assert(KPerBlock % QuantGroupSize == 0,
                  "KPerBlock must be a multiple of QuantGroupSize");

    /**
     * @brief Create the DRAM tile window used to load the A-scale tile.
     *
     * @param aq_dram_block_window_tmp Window whose bottom tensor view and origin are reused.
     * @return An MPerBlock x KPerBlockAQ window over the same view and origin, carrying the
     *         tile distribution from Policy::MakeAQDramTileDistribution<Problem>().
     */
    template <typename AQDramBlockWindowTmp>
    CK_TILE_DEVICE constexpr auto
    GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
    {
        // Only the row-major scale layout is implemented.
        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);

        using YPerTile = number<MPerBlock>;
        using XPerTile = number<KPerBlockAQ>;

        auto aq_copy_dram_window =
            make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(YPerTile(), XPerTile()),
                             aq_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeAQDramTileDistribution<Problem>());
        return aq_copy_dram_window;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_mx_utils.hpp"
namespace ck_tile {
/**
 * @brief Default policy for the MX GEMM pipeline.
 *
 * Extends UniversalGemmPipelineAgBgCrPolicy with the A-scale vector size, the DRAM tile
 * distribution for the A-scale tile, and the block GEMM selection. All scale-related
 * quantities are read from the problem's AScale* / kBlockScaleSize members.
 */
struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
    using Base = UniversalGemmPipelineAgBgCrPolicy;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::ATileAccessPattern;
    using Base::BTileAccessPattern;

    /**
     * @brief Vector load size for the A-scale tile (MPerBlock x KPerBlock / kBlockScaleSize).
     * @return Result of GetAScaleGlobalVectorLoadSize for that tile shape and scale data type.
     */
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAScale()
    {
        using AScaleLayout               = remove_cvref_t<typename Problem::AScaleLayout>;
        using AScaleDataType             = remove_cvref_t<typename Problem::AScaleDataType>;
        constexpr index_t MPerBlock      = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock      = Problem::BlockGemmShape::kK;
        constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize;

        // Only the row-major scale layout is implemented.
        static_assert(std::is_same_v<AScaleLayout, ck_tile::tensor_layout::gemm::RowMajor>);
        return GetAScaleGlobalVectorLoadSize<Problem, AScaleDataType, MPerBlock, KPerBlockScale>();
    }

    /**
     * @brief Compatibility spelling for callers that request the A-scale vector size under the
     *        AQ name; forwards to GetVectorSizeAScale.
     */
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
    {
        return GetVectorSizeAScale<Problem>();
    }

    /**
     * @brief Tile distribution used when loading the A-scale tile from DRAM.
     * @return TileDistributionEncodingPatternAQ<...>::Make2DStaticTileDistribution() for an
     *         MPerBlock x (KPerBlock / kBlockScaleSize) tile.
     */
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
    {
        using AQLayout       = remove_cvref_t<typename Problem::AScaleLayout>;
        using BlockGemmShape = typename Problem::BlockGemmShape;

        constexpr index_t BlockSize   = Problem::kBlockSize;
        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
        constexpr index_t KPerBlockAQ = KPerBlock / Problem::kBlockScaleSize;
        constexpr index_t VecLoadSize = GetVectorSizeAScale<Problem>();

        using WarpTile = typename Problem::BlockGemmShape::WarpTile;
        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
                                                typename Problem::ComputeDataType,
                                                typename Problem::CDataType,
                                                WarpTile::at(I0),
                                                WarpTile::at(I1),
                                                WarpTile::at(I2),
                                                false>;

        // Only the row-major scale layout is implemented.
        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);

        using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
                                                                      WarpGemm,
                                                                      BlockSize,
                                                                      MPerBlock,
                                                                      KPerBlockAQ,
                                                                      VecLoadSize>;
        return TileEncodingPattern::Make2DStaticTileDistribution();
    }

    /**
     * @brief Select the block GEMM used by the pipeline.
     *
     * Requires an fp8_t or bf8_t compute type, a float C type, and a scale block size that is
     * a whole multiple of the warp GEMM K extent.
     *
     * NOTE(review): WarpGemmMfmaDispatcher, BlockGemmASmemBSmemCRegV1CustomPolicy and
     * AQuantBlockUniversalGemmAsBsCr are not declared by anything this header includes
     * directly - confirm they arrive transitively.
     */
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;

        static_assert(Problem::kBlockScaleSize % WarpTile::at(I2) == 0,
                      "kBlockScaleSize must be a multiple of KPerWarpGemm!");

        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
                                                typename Problem::ComputeDataType,
                                                typename Problem::CDataType,
                                                WarpTile::at(I0),
                                                WarpTile::at(I1),
                                                WarpTile::at(I2),
                                                false>;

        static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
                      std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
        static_assert(std::is_same_v<typename Problem::CDataType, float>);

        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                      typename Problem::BDataType,
                                                                      typename Problem::CDataType,
                                                                      BlockWarps,
                                                                      WarpGemm>;
        return AQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,480 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
/**
 * @brief Base of the MX CompV3 pipeline; adds a runtime-to-compile-time tail dispatcher on top
 *        of BaseGemmPipelineAgBgCrCompV3.
 */
template <typename Problem>
struct BaseGemmMXPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    private:
    // Calls run_func with bool_constant<HotLoop> and the integral_constant matching tail_number.
    template <bool HotLoop, typename RunFunction>
    CK_TILE_HOST_DEVICE static auto DispatchTail(const RunFunction& run_func,
                                                 TailNumber tail_number)
    {
        using HotLoopTag = ck_tile::bool_constant<HotLoop>;

        switch(tail_number)
        {
        case ck_tile::TailNumber::Full:
            return run_func(
                HotLoopTag{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        case ck_tile::TailNumber::Odd:
            return run_func(
                HotLoopTag{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
        case ck_tile::TailNumber::Even:
            return run_func(
                HotLoopTag{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
        default: throw std::runtime_error("Unsupported tail number for this operation !!!");
        }
    }

    public:
    /**
     * @brief Invoke run_func with compile-time tags for the given runtime loop configuration.
     *
     * @param run_func     Callable taking (bool_constant<has_hot_loop>,
     *                     integral_constant<TailNumber, tail_number>).
     * @param has_hot_loop Selects the bool_constant passed as first argument.
     * @param tail_number  Full, Odd or Even; any other value throws std::runtime_error.
     * @return Whatever run_func returns.
     */
    template <typename RunFunction>
    CK_TILE_HOST_DEVICE static auto
    TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
    {
        if(has_hot_loop)
        {
            return DispatchTail<true>(run_func, tail_number);
        }
        return DispatchTail<false>(run_func, tail_number);
    }
};
/**
 * @brief CompV3 GEMM pipeline variant that additionally streams an A-side scale tile.
 *
 * A and B tiles are prefetched from DRAM, written to LDS and consumed by the block GEMM.
 * The scale tile is prefetched from DRAM into one of two alternating tiles and handed to each
 * block GEMM call.
 *
 * @tparam Problem Pipeline problem (data types, layouts, block shape, kBlockScaleSize, ...).
 * @tparam Policy  Supplies vector sizes, LDS packing, tile distributions and the block GEMM.
 */
template <typename Problem, typename Policy = GemmMXPipelineAgBgCrDefaultPolicy>
struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV3<Problem>;
    using PipelineImplBase = GemmMXPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using AScaleDataType = remove_cvref_t<typename Problem::AScaleDataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BScaleDataType = remove_cvref_t<typename Problem::BScaleDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t APackedSize =
        ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
    static constexpr index_t BPackedSize =
        ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
    static constexpr index_t AScalePackedSize =
        ck_tile::numeric_traits<remove_cvref_t<AScaleDataType>>::PackedSize;
    static constexpr index_t BScalePackedSize =
        ck_tile::numeric_traits<remove_cvref_t<BScaleDataType>>::PackedSize;

    using ALayout      = remove_cvref_t<typename Problem::ALayout>;
    using AScaleLayout = remove_cvref_t<typename Problem::AScaleLayout>;
    using BLayout      = remove_cvref_t<typename Problem::BLayout>;
    using BScaleLayout = remove_cvref_t<typename Problem::BScaleLayout>;
    using CLayout      = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // Number of K elements sharing one scale value, and scale values per block tile along K.
    static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
    static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize;
    // Same value under the name kernel-side code reads (GemmPipeline::QuantGroupSize).
    // NOTE(review): confirm the kernel in this change is instantiated with this pipeline.
    static constexpr index_t QuantGroupSize = BlockScaleSize;

    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
    // Vector size for the A-scale tile; the default policy exposes it as GetVectorSizeAScale.
    static constexpr index_t GetVectorSizeScale()
    {
        return Policy::template GetVectorSizeAScale<Problem>();
    }

    static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
    static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    /// Underscore-joined identifier built from the tile shape, block size, warp layout,
    /// warp GEMM shape, padding flags and scale block size.
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
        return concat('_', "mx_pipeline_AgBgCrCompV3",
                      concat('x', MPerBlock, NPerBlock, KPerBlock),
                      BlockSize,
                      concat('x', WaveNumM, WaveNumN),
                      concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
                      concat('x', kPadM, kPadN, kPadK), "BlockScaleSize", BlockScaleSize);
        // clang-format on
    }

    /// Shared-memory size requested by the policy for this problem.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    /// Human-readable summary of vector sizes and per-tile instruction-count estimates.
    CK_TILE_HOST static std::string Print()
    {
        constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
        constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
        constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

        constexpr index_t WaveSize = 64;
        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

        constexpr index_t A_LDS_Read_Width = GetSmemPackA();
        constexpr index_t B_LDS_Read_Width = GetSmemPackB();

        constexpr index_t A_LDS_Write_Width = GetSmemPackA();
        constexpr index_t B_LDS_Write_Width = GetSmemPackB();

        constexpr index_t A_Buffer_Load_Inst_Num =
            MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
        constexpr index_t B_Buffer_Load_Inst_Num =
            NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
        constexpr index_t AQ_Buffer_Load_Inst_Num =
            MPerBlock * KPerBlockScale / (BlockSize * GetVectorSizeScale());

        constexpr index_t A_LDS_Write_Inst_Num =
            MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
        constexpr index_t B_LDS_Write_Inst_Num =
            NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

        constexpr index_t A_LDS_Read_Inst_Num =
            WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
        constexpr index_t B_LDS_Read_Inst_Num =
            WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

        constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                            (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);

        auto str = std::stringstream{};

        str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
            << "AQ vector size: " << GetVectorSizeScale() << "\n"
            << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
            << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
            << ", "
            << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
            << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
            << "\n"
            << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
            << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
            << "QuantGroupSize: " << BlockScaleSize << "\n"
            << "KPack: " << BlockGemm::Traits::KPack << "\n"
            << "PrefetchStages: " << PrefetchStages << "\n";
        return str.str();
    }

    // Scheduler-specific implementation; only the Intrawave specialization has a body.
    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        /**
         * @brief Run the pipeline over num_loop K iterations and return the accumulated C tile.
         *
         * @tparam HasHotLoop Whether the steady-state do/while loop is emitted.
         * @tparam TailNum    Full/Odd run one final block GEMM; any other value takes the
         *                    two-GEMM tail path.
         * @param a_dram_block_window_tmp  DRAM window for A.
         * @param a_element_func           Element-wise function applied when A is written to LDS.
         * @param b_dram_block_window_tmp  DRAM window for B.
         * @param b_element_func           Element-wise function applied when B is written to LDS.
         * @param aq_dram_block_window_tmp DRAM window for the A-scale tensor
         *                                 (MPerBlock x KPerBlockScale).
         * @param num_loop                 Loop count; the hot loop runs while i < num_loop - 1.
         * @param p_smem                   Shared-memory base used for the A/B LDS tensor views.
         */
        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AQDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
                    std::is_same_v<AScaleDataType,
                                   remove_cvref_t<typename AQDramBlockWindowTmp::DataType>>,
                "A/B/AQ Dram block window should have the same data type as appropriate "
                "([A|B|AQ]DataType) defined in Problem definition!");

            constexpr bool is_a_col_major =
                std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
            constexpr bool is_aq_col_major =
                std::is_same_v<AScaleLayout, tensor_layout::gemm::ColumnMajor>;
            constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

            static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
            static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlockScale == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "Aq block window has incorrect lengths for defined AqLayout!");

            static_assert(is_a_col_major
                              ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "A block window has incorrect lengths for defined ALayout!");
            static_assert(is_b_row_major
                              ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "B block window has incorrect lengths for defined BLayout!");

            using ADramTileWindowStep  = typename ADramBlockWindowTmp::BottomTensorIndex;
            using BDramTileWindowStep  = typename BDramBlockWindowTmp::BottomTensorIndex;
            using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;

            auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);

            constexpr auto a_lds_load_tile_distr =
                make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
            constexpr auto b_lds_load_tile_distr =
                make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());

            auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
                Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
            auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
                Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
            auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);

            using ABlockTileDistr  = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr  = decltype(b_copy_dram_window.get_tile_distribution());
            using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
            using AQBlockTile =
                decltype(make_static_distributed_tensor<AScaleDataType>(AQBlockTileDistr{}));

            auto block_gemm = BlockGemm();

            ABlockTile a_block_tile;
            BBlockTile b_block_tile;
            // Two scale tiles: one is consumed by block_gemm while the other is prefetched.
            AQBlockTile aq_block_tile[2];
            int currIdx = 0;

            auto c_block_tile = block_gemm.MakeCBlockTile();

            // Every window advances by one block tile along K per prefetch.
            constexpr ADramTileWindowStep a_dram_tile_window_step =
                is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
            constexpr BDramTileWindowStep b_dram_tile_window_step =
                is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
            constexpr AQDramTileWindowStep aq_dram_tile_window_step =
                is_aq_col_major ? make_array(KPerBlockScale, 0) : make_array(0, KPerBlockScale);

            // DRAM prefetch (global read 0)
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
            Base::GlobalPrefetch(
                aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);

            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // First LDS fill.
            // NOTE(review): this first fill uses MakeShuffled2DStaticTileDistribution while the
            // later fills use MakeShuffledARegTileDistribution / MakeShuffledBRegTileDistribution
            // - confirm that the difference is intentional.
            if constexpr(is_a_col_major)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
                transpose_tile2d(a_shuffle_tmp, a_block_tile);
                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
            }
            else
            {
                Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
            }
            if constexpr(is_b_row_major)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
                transpose_tile2d(b_shuffle_tmp, b_block_tile);
                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
            }
            else
            {
                Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
            }

            // DRAM prefetch of the next A/B tiles (global read 1).
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

            block_sync_lds();

            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);

            __builtin_amdgcn_sched_barrier(0);

            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    block_sync_lds();

                    if constexpr(is_a_col_major)
                    {
                        auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                            Policy::template MakeShuffledARegTileDistribution<Problem>());
                        transpose_tile2d(a_shuffle_tmp, a_block_tile);
                        Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                    }
                    if constexpr(is_b_row_major)
                    {
                        auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                            Policy::template MakeShuffledBRegTileDistribution<Problem>());
                        transpose_tile2d(b_shuffle_tmp, b_block_tile);
                        Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
                    }

                    Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
                    Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
                    // Prefetch the next scale tile into the slot not used by this block_gemm.
                    Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
                                         aq_copy_dram_window,
                                         aq_dram_tile_window_step);

                    block_gemm(
                        c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);

                    currIdx = (currIdx + 1) % 2;

                    block_sync_lds();

                    block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);

                    __builtin_amdgcn_sched_barrier(0);

                    i += 1;
                } while(i < (num_loop - 1));
            }

            // tail
            if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
            {
                block_gemm(
                    c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
            }
            else
            {
                // Two block GEMMs remain: fetch the last scale tile, consume the current one,
                // then refill LDS from the already-prefetched A/B tiles for the final GEMM.
                Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
                                     aq_copy_dram_window,
                                     aq_dram_tile_window_step);
                block_gemm(
                    c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
                block_sync_lds();

                currIdx = (currIdx + 1) % 2;

                if constexpr(is_a_col_major)
                {
                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                        Policy::template MakeShuffledARegTileDistribution<Problem>());
                    transpose_tile2d(a_shuffle_tmp, a_block_tile);
                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
                }
                else
                {
                    Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                }
                if constexpr(is_b_row_major)
                {
                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
                    transpose_tile2d(b_shuffle_tmp, b_block_tile);
                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
                }
                else
                {
                    Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
                }
                block_sync_lds();
                block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
                block_gemm(
                    c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
            }
            return c_block_tile;
        }
    };

    /**
     * @brief Run the pipeline with identity element functions for A and B.
     *
     * Dispatches to PipelineImpl<Scheduler> with the problem's HasHotLoop and TailNum.
     *
     * @param a_dram_block_window_tmp  DRAM window for A.
     * @param b_dram_block_window_tmp  DRAM window for B.
     * @param aq_dram_block_window_tmp DRAM window for the A-scale tensor.
     * @param num_loop                 Loop count forwarded to the implementation.
     * @param p_smem                   Shared-memory base pointer.
     * @return The accumulated C block tile.
     */
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AQDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            aq_dram_block_window_tmp,
            num_loop,
            p_smem);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,126 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include <string>
namespace ck_tile {
/**
 * @brief Problem description for the MX GEMM pipeline.
 *
 * Extends GemmPipelineProblemBase with the A/B scale data types and layouts, the scale block
 * size, and the scheduler / hot-loop / tail configuration.
 *
 * @tparam AScaleDataType_ / BScaleDataType_ Element types of the scale tensors.
 * @tparam Traits_          Must provide AScaleLayout and BScaleLayout besides the base traits.
 * @tparam BlockScaleSize_  Number of K elements per scale value; must divide BlockGemmShape::kK.
 * @tparam Scheduler_       Only GemmPipelineScheduler::Intrawave is accepted.
 */
template <typename ADataType_,
          typename AScaleDataType_,
          typename BDataType_,
          typename BScaleDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename Traits_,
          uint32_t BlockScaleSize_,
          typename ComputeDataType_        = BDataType_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_                 = true,
          TailNumber TailNum_              = TailNumber::Full>
struct GemmMXPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
                                                                  BDataType_,
                                                                  CDataType_,
                                                                  BlockGemmShape_,
                                                                  Traits_,
                                                                  ComputeDataType_>
{
    using Base = GemmPipelineProblemBase<ADataType_,
                                         BDataType_,
                                         CDataType_,
                                         BlockGemmShape_,
                                         Traits_,
                                         ComputeDataType_>;

    using Traits = typename Base::Traits;

    using typename Base::ADataType;
    using typename Base::BDataType;
    using typename Base::CDataType;
    using typename Base::ComputeDataType;
    using AScaleDataType = remove_cvref_t<AScaleDataType_>;
    using BScaleDataType = remove_cvref_t<BScaleDataType_>;

    using BlockGemmShape = typename Base::BlockGemmShape;

    using typename Base::ALayout;
    using typename Base::BLayout;
    using typename Base::CLayout;

    static constexpr bool TransposeC = false;

    using Base::kBlockSize;

    using Base::kPadK;
    using Base::kPadM;
    using Base::kPadN;

    using Base::DoubleSmemBuffer;
    using Base::VectorLoadSize;

    using AScaleLayout = remove_cvref_t<typename Traits::AScaleLayout>;
    using BScaleLayout = remove_cvref_t<typename Traits::BScaleLayout>;

    static constexpr uint32_t kBlockScaleSize = BlockScaleSize_;
    static constexpr auto Scheduler           = Scheduler_;
    static constexpr auto HasHotLoop          = HasHotLoop_;
    static constexpr auto TailNum             = TailNum_;

    static_assert(BlockGemmShape::kK % kBlockScaleSize == 0);
    static_assert(Scheduler == GemmPipelineScheduler::Intrawave);

    /// Underscore-joined identifier built from vector load size, block size, padding flags,
    /// scheduler and scale block size.
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm_mx_problem",
                      concat('x', VectorLoadSize, kBlockSize),
                      concat('x', kPadM, kPadN, kPadK),
                      Scheduler,
                      "BlockScaleSize",
                      kBlockScaleSize);
        // clang-format on
    }

    /// Number of A-scale elements that fit in VectorLoadSize bytes (row-major scales only).
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAScale()
    {
        static_assert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
        return VectorLoadSize / sizeof(AScaleDataType);
    }

    // Vector size for A-scale loads: 1 when K is padded, otherwise the full alignment.
    // The layout requirement applies to the scale tensor, not to A itself.
    static constexpr index_t VectorSizeAScale = []() {
        static_assert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
        return kPadK ? 1 : GetAlignmentAScale();
    }();
};
/// Convenience alias that forwards every parameter, in order, to GemmMXPipelineProblemBase.
template <typename ADataType_,
          typename AScaleDataType_,
          typename BDataType_,
          typename BScaleDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename Traits_,
          uint32_t BlockScaleSize_,
          typename ComputeDataType_        = BDataType_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_                 = true,
          TailNumber TailNum_              = TailNumber::Full>
using GemmMXPipelineProblem = GemmMXPipelineProblemBase<ADataType_,
                                                        AScaleDataType_,
                                                        BDataType_,
                                                        BScaleDataType_,
                                                        CDataType_,
                                                        BlockGemmShape_,
                                                        Traits_,
                                                        BlockScaleSize_,
                                                        ComputeDataType_,
                                                        Scheduler_,
                                                        HasHotLoop_,
                                                        TailNum_>;
} // namespace ck_tile

View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
/**
 * @brief Choose the vector width (in elements) for loading a YPerTile x XPerTile scale tile.
 *
 * Candidate widths are PackedSize * {32, 16, 8, 4, 2} / sizeof(DataType), tried from widest to
 * narrowest. A candidate is taken when it is positive, divides XPerTile, divides the
 * per-thread element count, and differs from the last (narrowest) candidate. If none is
 * taken, PackedSize is returned.
 *
 * @tparam Problem  Supplies BlockGemmShape::BlockWarps and kBlockSize.
 * @tparam DataType Element type of the scale tensor.
 * @tparam YPerTile Tile extent along Y.
 * @tparam XPerTile Tile extent along X.
 */
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetAScaleGlobalVectorLoadSize()
{
    constexpr index_t warps_along_n = Problem::BlockGemmShape::BlockWarps::at(number<1>{});
    constexpr index_t block_size    = Problem::kBlockSize;
    constexpr index_t packed_size = ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;

    // Data is replicated across warps along NWarps, so only BlockSize / NWarps threads hold
    // distinct elements.
    constexpr index_t per_thread_elems = (YPerTile * XPerTile) / (block_size / warps_along_n);

    // Widest first.
    constexpr std::array<index_t, 5> widths{
        packed_size * 32 / sizeof(DataType),
        packed_size * 16 / sizeof(DataType),
        packed_size * 8 / sizeof(DataType),
        packed_size * 4 / sizeof(DataType),
        packed_size * 2 / sizeof(DataType),
    };
    constexpr index_t narrowest = widths[4];

    for(const index_t width : widths)
    {
        // The positivity check comes first so the modulo operations never see a zero divisor.
        const bool divides_tile =
            (width > 0) && (XPerTile % width == 0) && (per_thread_elems % width == 0);
        if(divides_tile && width != narrowest)
        {
            return width;
        }
    }
    return packed_size; // Absolute fallback
}
// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
// threads. Post mfma scales are shuffled across threads in the warp and applied to
// accum registers.
/**
 * @brief Tile distribution encoding for a YPerTile x XPerTile scale tile.
 *
 * @tparam BlockGemmShape Supplies BlockWarps (M, N, K warp counts) and kM.
 * @tparam WarpGemm       Supplies kM, the warp-level M extent.
 * @tparam BlockSize      Threads per block; must equal warp size * MWarps * NWarps * KWarps.
 * @tparam YPerTile       Tile extent along Y; must equal Y0 * Y1 * Y2 * Y3.
 * @tparam XPerTile       Tile extent along X; must be a multiple of VecSize.
 * @tparam VecSize        Only used in the XPerTile divisibility check.
 */
template <typename BlockGemmShape,
          typename WarpGemm,
          index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
    // Warp bookkeeping.
    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
    static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
    static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});

    static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);

    // # of elements per thread
    static constexpr index_t X = XPerTile;

    // Factorization of the Y extent; Y1 falls back to 1 when MIterPerWarp is 0.
    static constexpr index_t Y0 = 1;
    static constexpr index_t Y1 = (MIterPerWarp != 0) ? MIterPerWarp : 1;
    static constexpr index_t Y2 = MWarps;
    static constexpr index_t Y3 = WarpGemm::kM;

    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
    static_assert(num_warps == MWarps * NWarps * KWarps);
    // KWarps > 1 isn't supported
    static_assert(KWarps == 1);
    static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
    static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
                  "Y0, Y1, Y2, Y3 must cover the blocktile along Y.");

    /// Build the static tile distribution from the encoding below.
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        using Encoding = tile_distribution_encoding<sequence<NWarps>,
                                                    tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
                                                    tuple<sequence<1, 0>, sequence<1, 1>>,
                                                    tuple<sequence<2, 0>, sequence<0, 3>>,
                                                    sequence<1, 2>,
                                                    sequence<1, 0>>;
        return make_static_tile_distribution(Encoding{});
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/**
 * @brief Compile-time traits bundle for the MX GEMM: padding flags plus the layouts of A, B, C
 *        and of the A/B scale tensors.
 */
template <bool kPadM_,
          bool kPadN_,
          bool kPadK_,
          typename ALayout_,
          typename AScaleLayout_,
          typename BLayout_,
          typename BScaleLayout_,
          typename CLayout_>
struct TileGemmMXTraits
{
    // Tensor layouts, re-exported from the template arguments.
    using ALayout      = ALayout_;
    using AScaleLayout = AScaleLayout_;
    using BLayout      = BLayout_;
    using BScaleLayout = BScaleLayout_;
    using CLayout      = CLayout_;

    // Per-dimension padding switches.
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;

    // Fixed settings of this traits type.
    static constexpr int _VectorSize            = 16;
    static constexpr bool UseStructuredSparsity = false;
    static constexpr index_t NumWaveGroups      = 1;
};
} // namespace ck_tile