mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4280 (commit b7de1e1)
[CK_TILE] Add blockscale GEMM support for EightWarps on gfx950 (#4280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes gemm blockscale eightwarps support ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [x] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
731afe535a
commit
5b3e527c88
@@ -7,7 +7,12 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
set(EXE_NAME tile_example_gemm_quant)
|
||||
|
||||
@@ -3,14 +3,17 @@
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if defined(CK_TILE_EIGHTWARP_SUP)
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigEightWarps<T>;
|
||||
template <typename T>
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T>;
|
||||
|
||||
template <typename T>
|
||||
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
|
||||
// template <typename T>
|
||||
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
|
||||
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
@@ -23,7 +26,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -53,7 +56,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -83,7 +86,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -98,7 +101,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -113,7 +116,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -128,7 +131,7 @@ static auto _ = []() {
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -173,7 +176,7 @@ static auto _ = []() {
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -188,11 +191,12 @@ static auto _ = []() {
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
return run_gemm_example_prec_type<
|
||||
GemmConfigPreshuffleB_ABQuant_Prefill<ck_tile::pk_fp4_raw_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -271,6 +271,29 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 192;
|
||||
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType) * K_Warp;
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
|
||||
@@ -34,11 +34,20 @@ template <typename GemmConfig,
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool transpose_c =
|
||||
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
constexpr bool IS_FP8BLOCKSCALE =
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 &&
|
||||
(std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
|
||||
constexpr bool transpose_c = GemmConfig::TransposeC;
|
||||
constexpr bool eight_warps =
|
||||
IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 &&
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
GemmConfig::K_Warp_Tile == 128;
|
||||
|
||||
// Use automatically determined compute type from
|
||||
using ComputeDataType = void;
|
||||
using ComputeDataType =
|
||||
std::conditional_t<IS_FP8BLOCKSCALE, typename TypeConfig::ADataType, void>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -71,19 +80,22 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ComputeDataType>;
|
||||
|
||||
// Base pipeline selection based on quant mode and preshuffle settings
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
|
||||
constexpr auto base_gemm_pipeline = []() {
|
||||
if constexpr(eight_warps)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else if constexpr(GemmConfig::PreshuffleB)
|
||||
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
|
||||
GemmConfig::APreshuffleQuant)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || IS_FP8BLOCKSCALE)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
|
||||
else
|
||||
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
|
||||
}();
|
||||
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;
|
||||
|
||||
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
@@ -163,10 +175,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline =
|
||||
using ABQuantPipeline = std::conditional_t<
|
||||
eight_warps,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
@@ -185,7 +199,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
@@ -957,20 +970,27 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped &&
|
||||
!GemmConfig::APreshuffleQuant && BQuantGroupSize::kN == 128 &&
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8))
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Col{}, Col{}, Col{}, Row{});
|
||||
else
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user