[rocm-libraries] ROCm/rocm-libraries#4280 (commit b7de1e1)

[CK_TILE] Add blockscale GEMM support for EightWarps on
 gfx950 (#4280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Add blockscale GEMM support for the EightWarps configuration on gfx950.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which helps the maintainers
understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered.
This commit is contained in:
kensclin
2026-02-09 03:55:52 +00:00
committed by assistant-librarian[bot]
parent 731afe535a
commit 5b3e527c88
19 changed files with 1881 additions and 225 deletions

View File

@@ -7,7 +7,12 @@ if(CK_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
if(GPU_TARGETS MATCHES "gfx95")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
set(EXE_NAME tile_example_gemm_quant)

View File

@@ -3,14 +3,17 @@
#include "run_gemm_quant_example.inc"
#if defined(CK_TILE_EIGHTWARP_SUP)
template <typename T>
using GemmConfig = GemmConfigEightWarps<T>;
template <typename T>
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T>;
#else
template <typename T>
using GemmConfig = GemmConfigABQuantPrefill<T>;
template <typename T>
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
#endif
static auto _ = []() {
auto& lut = get_kernel_lut();
@@ -23,7 +26,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -53,7 +56,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -83,7 +86,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -98,7 +101,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -113,7 +116,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -128,7 +131,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -173,7 +176,7 @@ static auto _ = []() {
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -188,11 +191,12 @@ static auto _ = []() {
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
return run_gemm_example_prec_type<
GemmConfigPreshuffleB_ABQuant_Prefill<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();

View File

@@ -271,6 +271,29 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
static constexpr bool TransposeC = true;
};
// GEMM example configuration derived from GemmConfigABQuantPrefill<PrecType>.
// It redeclares the warp counts, tile extents and a few flags below; every
// member not listed here is inherited unchanged from the base configuration.
// The three warp counts multiply to 4 * 2 * 1 = 8.
template <typename PrecType>
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
{
// Warp counts along the M, N and K dimensions.
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
static constexpr ck_tile::index_t K_Warp = 1;
// Tile extents. N_Tile evaluates to 128 * 2 = 256. K_Tile is
// 128 / sizeof(PrecType) elements multiplied by K_Warp, so it depends on the
// element size of the precision type.
static constexpr ck_tile::index_t M_Tile = 192;
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType) * K_Warp;
// K padding is disabled and the transposed-C flag is enabled for this configuration.
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
// NOTE(review): presumably the number of blocks resident per compute unit;
// confirm against the kernel launch code that consumes it.
static constexpr int kBlockPerCu = 1;
};
// Variant of GemmConfigEightWarps<PrecType> that additionally sets the
// PreshuffleB and DoubleSmemBuffer flags to true. Warp counts, tile extents
// and all remaining members are inherited from GemmConfigEightWarps.
template <typename PrecType>
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType>
{
// Both flags are enabled together for this variant.
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{

View File

@@ -34,11 +34,20 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
constexpr bool IS_FP8BLOCKSCALE =
QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 &&
(std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = GemmConfig::TransposeC;
constexpr bool eight_warps =
IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;
// Use automatically determined compute type from
using ComputeDataType = void;
using ComputeDataType =
std::conditional_t<IS_FP8BLOCKSCALE, typename TypeConfig::ADataType, void>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -71,19 +80,22 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ComputeDataType>;
// Base pipeline selection based on quant mode and preshuffle settings
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::ABQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
constexpr auto base_gemm_pipeline = []() {
if constexpr(eight_warps)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(GemmConfig::PreshuffleB)
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
GemmConfig::APreshuffleQuant)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || IS_FP8BLOCKSCALE)
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
else
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
}();
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -163,10 +175,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using ABQuantPipeline =
using ABQuantPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
@@ -185,7 +199,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
@@ -957,20 +970,27 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped &&
!GemmConfig::APreshuffleQuant && BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8))
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Col{}, Col{}, Col{}, Row{});
else
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if(a_layout == "R" && b_layout == "R")
{