[GEMM] Add macros for multiple optimization options

This commit is contained in:
Clement Lin
2025-03-29 22:58:51 +08:00
parent 428bcdeb40
commit 7bc473835e
8 changed files with 99 additions and 30 deletions

View File

@@ -13,6 +13,11 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
if(DEFINED kernel)
message("Compiling with Kernel: ${kernel}")
target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1)
endif()
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated

View File

@@ -6,10 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#define mfma_m32_n32_k8 0
#define mfma_m32_n32_k16 0
#define mfma_m16_n16_k16 0
#define mfma_m16_n16_k32 1
#include "config.h"
namespace ck_tile {
@@ -20,8 +17,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if mfma_m32_n32_k8
#pragma message ("mfma m32 n32 k8")
#if defined(USING_MFMA_32x32x_8x2)
#pragma message ("mfma m32 n32 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -34,8 +31,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#elif mfma_m32_n32_k16
#pragma message ("mfma m32 n32 k16")
#elif defined(NAIVE_IMPLEMENTATION)
#pragma message ("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -48,7 +45,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#elif mfma_m16_n16_k16
#elif defined(USING_MFMA_16x16x16)
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
@@ -62,7 +60,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
#elif mfma_m16_n16_k32
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&

View File

@@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
#if 1
#if defined(ENABLE_PREFETCH)
#pragma message ("prefetch")
// prefetch
// global read 0

View File

@@ -7,10 +7,7 @@
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_asmem_bsmem_creg.hpp"
#define BANK_CONFLICT_K_FIRST 0
#define PADDING_K_FIRST 0
#define PADDING_MN_FIRST 0
#define XOR 1
#include "config.h"
namespace ck_tile {
@@ -26,7 +23,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#if defined(NAIVE_IMPLEMENTATION)
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
@@ -41,7 +38,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#elif defined(PADDING_K_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
@@ -56,7 +53,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#elif defined(PADDING_MN_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
@@ -71,7 +68,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
#pragma message ("BANK_CONFLICT: XOR")
using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -125,7 +122,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#if defined(NAIVE_IMPLEMENTATION)
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
@@ -140,7 +137,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#elif defined(PADDING_K_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
@@ -155,7 +152,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#elif defined(PADDING_MN_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
@@ -170,7 +167,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
#pragma message ("BANK_CONFLICT: XOR")
using BDataType = remove_cvref_t<typename Problem::BDataType>;

View File

@@ -0,0 +1,31 @@
#if defined(KERNEL_A)
#define USING_MFMA_32x32x_8x2
#elif defined(KERNEL_B)
#define USING_MFMA_16x16x16
#elif defined(KERNEL_C)
#define USING_MFMA_16x16x_16x2
#elif defined(KERNEL_D)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#elif defined(KERNEL_E)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#elif defined(KERNEL_F)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#elif defined(KERNEL_G)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#define ENABLE_CACHE_AWARE_WG_SCH
#else
#define NAIVE_IMPLEMENTATION
#endif

View File

@@ -3,6 +3,7 @@
#include "ck_tile/host.hpp"
#include "reference_gemm.hpp"
#include "config.h"
#include "gemm.hpp"
/*
@@ -47,6 +48,43 @@ int main(int argc, char* argv[])
K = std::stoi(argv[4]);
}
#if defined(KERNEL_A)
printf("*** KernelA test *** \n");
printf(" --> Using mfma_16x16x(8x2)\n");
#elif defined(KERNEL_B)
printf("*** KernelB test *** \n");
printf(" --> Using mfma_16x16x16\n");
#elif defined(KERNEL_C)
printf("*** KernelC test *** \n");
printf(" --> Using mfma_16x16x(16x2)\n");
#elif defined(KERNEL_D)
printf("*** KernelD test *** \n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based banck conflict-free\n");
#elif defined(KERNEL_E)
printf("*** KernelE test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based banck conflict-free\n");
printf(" --> Adjust block tile shape\n");
#elif defined(KERNEL_F)
printf("*** KernelF test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based banck conflict-free\n");
printf(" --> Adjust block tile shape\n");
printf(" --> Enable prefetch\n");
printf(" --> Enable instruction schedule\n");
#elif defined(KERNEL_G)
printf("*** KernelG test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based banck conflict-free\n");
printf(" --> Adjust block tile shape\n");
printf(" --> Enable prefetch\n");
printf(" --> Enable instruction schedule\n");
printf(" --> Enable cache-aware thread blocks schedule\n");
#else
printf("*** Naive implementation test ***\n");
#endif
const ck_tile::index_t Lda = K;
const ck_tile::index_t Ldb = K;
const ck_tile::index_t Ldc = N;
@@ -82,7 +120,7 @@ int main(int argc, char* argv[])
constexpr ck_tile::index_t kBlockSize = 256;
#if 1
#ifdef ADJUST_BLOCK_TILE_SHAPE
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
constexpr ck_tile::index_t kGemmMPerBlock = 128;
constexpr ck_tile::index_t kGemmKPerBlock = 64;

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
#include "config.h"
#include "grid_gemm.hpp"
namespace ck_tile {
@@ -28,7 +29,7 @@ struct GridGemmProblem
using CElementFunction = CElementFunction_;
};
#ifndef INSTRUCTION_SCHEDULE
#ifndef ENABLE_INSTRUCTION_SCH
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
@@ -85,7 +86,7 @@ struct Gemm
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0,
index_t N0)
{
#if 1
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
#pragma message ("Cache-aware work group sch")
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
@@ -147,7 +148,7 @@ struct Gemm
#endif
}
#ifndef INSTRUCTION_SCHEDULE
#ifndef ENABLE_INSTRUCTION_SCH
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
{

View File

@@ -1,9 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#define INSTRUCTION_SCHEDULE
#ifdef INSTRUCTION_SCHEDULE
#ifdef ENABLE_INSTRUCTION_SCH
#include "instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "instruction_schedule/gemm_pipeline_problem.hpp"
#include "instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
@@ -59,7 +58,7 @@ struct GridGemm
auto b_block_window = make_tile_window(
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
#ifndef INSTRUCTION_SCHEDULE
#ifndef ENABLE_INSTRUCTION_SCH
#pragma message ("disable instruction scheduling")
// Block GEMM pipeline w/o instruction scheduling
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();